API Updates (#73)

* API keys are now passed from the client instead of the distribution configuration

* delete distribution registry

* Rename away the "package" terminology (e.g., build_package -> build_image)

* Introduce a "Router" layer for providers

Some providers need to be factored out and treated as thin routing
layers on top of other providers. Consider two examples:

- The inference API should be a routing layer over inference providers,
  routed using the "model" key
- The memory banks API is another instance where various memory bank
  types will be provided by independent providers (e.g., a vector store
  is served by Chroma while a keyvalue memory can be served by Redis or
  PGVector)

This commit introduces a generalized routing layer for this purpose.
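
To make the idea concrete, here is a minimal sketch of such a routing layer. The class and method names are hypothetical (this is not the implementation added by this commit); it only illustrates dispatching an API call to whichever provider is registered for a routing key such as the "model" field.

```python
from typing import Any, Dict, List, Protocol


class InferenceProvider(Protocol):
    """Anything that can serve the inference API."""

    async def chat_completion(self, model: str, messages: List[Any]) -> Any: ...


class InferenceRouter:
    """Thin routing layer: forwards each call to the provider registered for the model."""

    def __init__(self, providers_by_model: Dict[str, InferenceProvider]) -> None:
        self.providers_by_model = providers_by_model

    async def chat_completion(self, model: str, messages: List[Any]) -> Any:
        provider = self.providers_by_model.get(model)
        if provider is None:
            raise ValueError(f"No provider registered for model {model!r}")
        # The router adds no behavior of its own; it only dispatches.
        return await provider.chat_completion(model=model, messages=messages)
```

The same pattern applies to memory banks, with the bank type (vector, keyvalue, keyword, graph) acting as the routing key instead of the model name.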

* update `apis_to_serve`

* llama_toolchain -> llama_stack

* Codemod from llama_toolchain -> llama_stack

- added providers/registry
- cleaned up api/ subdirectories and moved impls away
- restructured api/api.py
- `from llama_stack.apis.<api> import foo` should work now
- update imports to use `llama_stack.apis.<api>`
- update many other imports
- added __init__, fixed some registry imports
- updated registry imports
- create_agentic_system -> create_agent
- AgenticSystem -> Agent
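
To illustrate the renames above (together with the API-keys-from-client change), here is a hedged sketch of what client code looks like afterwards, adapted from the example client in this diff. The module path and the `AgentConfig` fields are assumptions, not authoritative.

```python
# Sketch of post-rename client usage; module path and AgentConfig fields are assumed.
import asyncio
import os

from llama_stack.apis.agents import AgentConfig, SearchEngineType, SearchToolDefinition
from llama_stack.apis.agents.client import AgentsClient  # formerly AgenticSystemClient


async def main() -> None:
    api = AgentsClient("http://localhost:5000")

    # API keys are now supplied by the client (e.g. from the environment),
    # rather than baked into the distribution configuration.
    search_tool = SearchToolDefinition(
        engine=SearchEngineType.brave,
        api_key=os.getenv("BRAVE_SEARCH_API_KEY"),
    )

    agent_config = AgentConfig(  # field names assumed for illustration
        model="Meta-Llama3.1-8B-Instruct",
        instructions="You are a helpful assistant.",
        tools=[search_tool],
    )

    create_response = await api.create_agent(agent_config)  # was create_agentic_system
    session = await api.create_agent_session(  # was create_agentic_system_session
        agent_id=create_response.agent_id,
        session_name="test_session",
    )
    print(create_response.agent_id, session.session_id)


if __name__ == "__main__":
    asyncio.run(main())
```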

* Moved some stuff out of common/; re-generated OpenAPI spec

* llama-toolchain -> llama-stack (hyphens)

* add control plane API

* add redis adapter + sqlite provider

* move core -> distribution

* Some more toolchain -> stack changes

* small naming shenanigans

* Removing custom tool and agent utilities and moving them client side

* Move control plane to distribution server for now

* Remove control plane from API list

* Avoid pulling in the codeshield dependency unnecessarily

* Add "fire" as a dependency

* add back event loggers

* stack configure fixes

* use brave instead of bing in the example client

* add init file so it gets packaged

* add init files so it gets packaged

* Update MANIFEST

* bug fix

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Xi Yan <xiyan@meta.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Authored by Ashwin Bharambe on 2024-09-17 19:51:35 -07:00; committed by GitHub
parent f294eac5f5
commit 9487ad8294
213 changed files with 1725 additions and 1204 deletions


@ -1,4 +1,5 @@
include requirements.txt include requirements.txt
include llama_toolchain/data/*.yaml include llama_stack/distribution/*.sh
include llama_toolchain/core/*.sh include llama_stack/cli/scripts/*.sh
include llama_toolchain/cli/scripts/*.sh include llama_stack/distribution/example_configs/conda/*.yaml
include llama_stack/distribution/example_configs/docker/*.yaml


@ -1,6 +1,6 @@
# llama-stack # llama-stack
[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-toolchain)](https://pypi.org/project/llama-toolchain/) [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/)
[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/TZAAYNVtrU) [![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/TZAAYNVtrU)
This repository contains the specifications and implementations of the APIs which are part of the Llama Stack. This repository contains the specifications and implementations of the APIs which are part of the Llama Stack.
@ -42,7 +42,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
## Installation ## Installation
You can install this repository as a [package](https://pypi.org/project/llama-toolchain/) with `pip install llama-toolchain` You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack`
If you want to install from source: If you want to install from source:


@ -1,6 +1,6 @@
# Llama CLI Reference # Llama CLI Reference
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-toolchain` package. The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
### Subcommands ### Subcommands
1. `download`: `llama` cli tools supports downloading the model from Meta or HuggingFace. 1. `download`: `llama` cli tools supports downloading the model from Meta or HuggingFace.
@ -276,16 +276,16 @@ The following command and specifications allows you to get started with building
``` ```
llama stack build <path/to/config> llama stack build <path/to/config>
``` ```
- You will be required to pass in a file path to the build.config file (e.g. `./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_toolchain/configs/distributions/` folder. - You will be required to pass in a file path to the build.config file (e.g. `./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_stack/distribution/example_configs/` folder.
The file will be of the contents The file will be of the contents
``` ```
$ cat ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml $ cat ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml
name: 8b-instruct name: 8b-instruct
distribution_spec: distribution_spec:
distribution_type: local distribution_type: local
description: Use code from `llama_toolchain` itself to serve all llama stack APIs description: Use code from `llama_stack` itself to serve all llama stack APIs
docker_image: null docker_image: null
providers: providers:
inference: meta-reference inference: meta-reference
@ -311,7 +311,7 @@ After this step is complete, a file named `8b-instruct-build.yaml` will be gener
To specify a different API provider, we can change the `distribution_spec` in our `<name>-build.yaml` config. For example, the following build spec allows you to build a distribution using TGI as the inference API provider. To specify a different API provider, we can change the `distribution_spec` in our `<name>-build.yaml` config. For example, the following build spec allows you to build a distribution using TGI as the inference API provider.
``` ```
$ cat ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml $ cat ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml
name: local-tgi-conda-example name: local-tgi-conda-example
distribution_spec: distribution_spec:
@ -328,7 +328,7 @@ image_type: conda
The following command allows you to build a distribution with TGI as the inference API provider, with the name `tgi`. The following command allows you to build a distribution with TGI as the inference API provider, with the name `tgi`.
``` ```
llama stack build --config ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml --name tgi llama stack build --config ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml --name tgi
``` ```
We provide some example build configs to help you get started with building with different API providers. We provide some example build configs to help you get started with building with different API providers.
@ -337,11 +337,11 @@ We provide some example build configs to help you get started with building with
To build a docker image, simply change the `image_type` to `docker` in our `<name>-build.yaml` file, and run `llama stack build --config <name>-build.yaml`. To build a docker image, simply change the `image_type` to `docker` in our `<name>-build.yaml` file, and run `llama stack build --config <name>-build.yaml`.
``` ```
$ cat ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml $ cat ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml
name: local-docker-example name: local-docker-example
distribution_spec: distribution_spec:
description: Use code from `llama_toolchain` itself to serve all llama stack APIs description: Use code from `llama_stack` itself to serve all llama stack APIs
docker_image: null docker_image: null
providers: providers:
inference: meta-reference inference: meta-reference
@ -354,7 +354,7 @@ image_type: docker
The following command allows you to build a Docker image with the name `docker-local` The following command allows you to build a Docker image with the name `docker-local`
``` ```
llama stack build --config ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml --name docker-local llama stack build --config ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml --name docker-local
Dockerfile created successfully in /tmp/tmp.I0ifS2c46A/DockerfileFROM python:3.10-slim Dockerfile created successfully in /tmp/tmp.I0ifS2c46A/DockerfileFROM python:3.10-slim
WORKDIR /app WORKDIR /app
@ -480,9 +480,9 @@ This server is running a Llama model locally.
Once the server is setup, we can test it with a client to see the example outputs. Once the server is setup, we can test it with a client to see the example outputs.
``` ```
cd /path/to/llama-stack cd /path/to/llama-stack
conda activate <env> # any environment containing the llama-toolchain pip package will work conda activate <env> # any environment containing the llama-stack pip package will work
python -m llama_toolchain.inference.client localhost 5000 python -m llama_stack.apis.inference.client localhost 5000
``` ```
This will run the chat completion client and query the distributions /inference/chat_completion API. This will run the chat completion client and query the distributions /inference/chat_completion API.
@ -500,7 +500,7 @@ You know what's even more hilarious? People like you who think they can just Goo
Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by: Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by:
``` ```
python -m llama_toolchain.safety.client localhost 5000 python -m llama_stack.safety.client localhost 5000
``` ```
You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo. You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.


@ -1,6 +1,6 @@
# Getting Started # Getting Started
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-toolchain` package. The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
This guides allows you to quickly get started with building and running a Llama Stack server in < 5 minutes! This guides allows you to quickly get started with building and running a Llama Stack server in < 5 minutes!
@ -9,7 +9,7 @@ This guides allows you to quickly get started with building and running a Llama
**`llama stack build`** **`llama stack build`**
``` ```
llama stack build --config ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml --name my-local-llama-stack llama stack build --config ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml --name my-local-llama-stack
... ...
... ...
Build spec configuration saved at ~/.llama/distributions/conda/my-local-llama-stack-build.yaml Build spec configuration saved at ~/.llama/distributions/conda/my-local-llama-stack-build.yaml
@ -97,16 +97,16 @@ The following command and specifications allows you to get started with building
``` ```
llama stack build <path/to/config> llama stack build <path/to/config>
``` ```
- You will be required to pass in a file path to the build.config file (e.g. `./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_toolchain/configs/distributions/` folder. - You will be required to pass in a file path to the build.config file (e.g. `./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_stack/distribution/example_configs/` folder.
The file will be of the contents The file will be of the contents
``` ```
$ cat ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml $ cat ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml
name: 8b-instruct name: 8b-instruct
distribution_spec: distribution_spec:
distribution_type: local distribution_type: local
description: Use code from `llama_toolchain` itself to serve all llama stack APIs description: Use code from `llama_stack` itself to serve all llama stack APIs
docker_image: null docker_image: null
providers: providers:
inference: meta-reference inference: meta-reference
@ -132,7 +132,7 @@ After this step is complete, a file named `8b-instruct-build.yaml` will be gener
To specify a different API provider, we can change the `distribution_spec` in our `<name>-build.yaml` config. For example, the following build spec allows you to build a distribution using TGI as the inference API provider. To specify a different API provider, we can change the `distribution_spec` in our `<name>-build.yaml` config. For example, the following build spec allows you to build a distribution using TGI as the inference API provider.
``` ```
$ cat ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml $ cat ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml
name: local-tgi-conda-example name: local-tgi-conda-example
distribution_spec: distribution_spec:
@ -149,7 +149,7 @@ image_type: conda
The following command allows you to build a distribution with TGI as the inference API provider, with the name `tgi`. The following command allows you to build a distribution with TGI as the inference API provider, with the name `tgi`.
``` ```
llama stack build --config ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml --name tgi llama stack build --config ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml --name tgi
``` ```
We provide some example build configs to help you get started with building with different API providers. We provide some example build configs to help you get started with building with different API providers.
@ -158,11 +158,11 @@ We provide some example build configs to help you get started with building with
To build a docker image, simply change the `image_type` to `docker` in our `<name>-build.yaml` file, and run `llama stack build --config <name>-build.yaml`. To build a docker image, simply change the `image_type` to `docker` in our `<name>-build.yaml` file, and run `llama stack build --config <name>-build.yaml`.
``` ```
$ cat ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml $ cat ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml
name: local-docker-example name: local-docker-example
distribution_spec: distribution_spec:
description: Use code from `llama_toolchain` itself to serve all llama stack APIs description: Use code from `llama_stack` itself to serve all llama stack APIs
docker_image: null docker_image: null
providers: providers:
inference: meta-reference inference: meta-reference
@ -175,7 +175,7 @@ image_type: docker
The following command allows you to build a Docker image with the name `docker-local` The following command allows you to build a Docker image with the name `docker-local`
``` ```
llama stack build --config ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml --name docker-local llama stack build --config ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml --name docker-local
Dockerfile created successfully in /tmp/tmp.I0ifS2c46A/DockerfileFROM python:3.10-slim Dockerfile created successfully in /tmp/tmp.I0ifS2c46A/DockerfileFROM python:3.10-slim
WORKDIR /app WORKDIR /app
@ -294,9 +294,9 @@ This server is running a Llama model locally.
Once the server is setup, we can test it with a client to see the example outputs. Once the server is setup, we can test it with a client to see the example outputs.
``` ```
cd /path/to/llama-stack cd /path/to/llama-stack
conda activate <env> # any environment containing the llama-toolchain pip package will work conda activate <env> # any environment containing the llama-stack pip package will work
python -m llama_toolchain.inference.client localhost 5000 python -m llama_stack.apis.inference.client localhost 5000
``` ```
This will run the chat completion client and query the distributions /inference/chat_completion API. This will run the chat completion client and query the distributions /inference/chat_completion API.
@ -314,7 +314,7 @@ You know what's even more hilarious? People like you who think they can just Goo
Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by: Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by:
``` ```
python -m llama_toolchain.safety.client localhost 5000 python -m llama_stack.apis.safety.client localhost 5000
``` ```
You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo. You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .agents import * # noqa: F401 F403


@ -14,10 +14,10 @@ from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.common.deployment_types import * # noqa: F403 from llama_stack.apis.common.deployment_types import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
@json_schema_type @json_schema_type
@ -26,7 +26,7 @@ class Attachment(BaseModel):
mime_type: str mime_type: str
class AgenticSystemTool(Enum): class AgentTool(Enum):
brave_search = "brave_search" brave_search = "brave_search"
wolfram_alpha = "wolfram_alpha" wolfram_alpha = "wolfram_alpha"
photogen = "photogen" photogen = "photogen"
@ -50,41 +50,35 @@ class SearchEngineType(Enum):
class SearchToolDefinition(ToolDefinitionCommon): class SearchToolDefinition(ToolDefinitionCommon):
# NOTE: brave_search is just a placeholder since model always uses # NOTE: brave_search is just a placeholder since model always uses
# brave_search as tool call name # brave_search as tool call name
type: Literal[AgenticSystemTool.brave_search.value] = ( type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
AgenticSystemTool.brave_search.value api_key: str
)
engine: SearchEngineType = SearchEngineType.brave engine: SearchEngineType = SearchEngineType.brave
remote_execution: Optional[RestAPIExecutionConfig] = None remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type @json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon): class WolframAlphaToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.wolfram_alpha.value] = ( type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
AgenticSystemTool.wolfram_alpha.value api_key: str
)
remote_execution: Optional[RestAPIExecutionConfig] = None remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type @json_schema_type
class PhotogenToolDefinition(ToolDefinitionCommon): class PhotogenToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
remote_execution: Optional[RestAPIExecutionConfig] = None remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type @json_schema_type
class CodeInterpreterToolDefinition(ToolDefinitionCommon): class CodeInterpreterToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.code_interpreter.value] = ( type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
AgenticSystemTool.code_interpreter.value
)
enable_inline_code_execution: bool = True enable_inline_code_execution: bool = True
remote_execution: Optional[RestAPIExecutionConfig] = None remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type @json_schema_type
class FunctionCallToolDefinition(ToolDefinitionCommon): class FunctionCallToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.function_call.value] = ( type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
AgenticSystemTool.function_call.value
)
function_name: str function_name: str
description: str description: str
parameters: Dict[str, ToolParamDefinition] parameters: Dict[str, ToolParamDefinition]
@ -95,30 +89,30 @@ class _MemoryBankConfigCommon(BaseModel):
bank_id: str bank_id: str
class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon): class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon): class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on keys: List[str] # what keys to focus on
class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon): class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon): class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[ MemoryBankConfig = Annotated[
Union[ Union[
AgenticSystemVectorMemoryBankConfig, AgentVectorMemoryBankConfig,
AgenticSystemKeyValueMemoryBankConfig, AgentKeyValueMemoryBankConfig,
AgenticSystemKeywordMemoryBankConfig, AgentKeywordMemoryBankConfig,
AgenticSystemGraphMemoryBankConfig, AgentGraphMemoryBankConfig,
], ],
Field(discriminator="type"), Field(discriminator="type"),
] ]
@ -158,7 +152,7 @@ MemoryQueryGeneratorConfig = Annotated[
class MemoryToolDefinition(ToolDefinitionCommon): class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value type: Literal[AgentTool.memory.value] = AgentTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list) memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
# This config defines how a query is generated using the messages # This config defines how a query is generated using the messages
# for memory bank retrieval. # for memory bank retrieval.
@ -169,7 +163,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
max_chunks: int = 10 max_chunks: int = 10
AgenticSystemToolDefinition = Annotated[ AgentToolDefinition = Annotated[
Union[ Union[
SearchToolDefinition, SearchToolDefinition,
WolframAlphaToolDefinition, WolframAlphaToolDefinition,
@ -275,7 +269,7 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list) input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list) output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list) tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field( tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json default=ToolPromptFormat.json
@ -292,7 +286,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None instructions: Optional[str] = None
class AgenticSystemTurnResponseEventType(Enum): class AgentTurnResponseEventType(Enum):
step_start = "step_start" step_start = "step_start"
step_complete = "step_complete" step_complete = "step_complete"
step_progress = "step_progress" step_progress = "step_progress"
@ -302,9 +296,9 @@ class AgenticSystemTurnResponseEventType(Enum):
@json_schema_type @json_schema_type
class AgenticSystemTurnResponseStepStartPayload(BaseModel): class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = ( event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
AgenticSystemTurnResponseEventType.step_start.value AgentTurnResponseEventType.step_start.value
) )
step_type: StepType step_type: StepType
step_id: str step_id: str
@ -312,20 +306,20 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
@json_schema_type @json_schema_type
class AgenticSystemTurnResponseStepCompletePayload(BaseModel): class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = ( event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
AgenticSystemTurnResponseEventType.step_complete.value AgentTurnResponseEventType.step_complete.value
) )
step_type: StepType step_type: StepType
step_details: Step step_details: Step
@json_schema_type @json_schema_type
class AgenticSystemTurnResponseStepProgressPayload(BaseModel): class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = ( event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
AgenticSystemTurnResponseEventType.step_progress.value AgentTurnResponseEventType.step_progress.value
) )
step_type: StepType step_type: StepType
step_id: str step_id: str
@ -336,49 +330,49 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
@json_schema_type @json_schema_type
class AgenticSystemTurnResponseTurnStartPayload(BaseModel): class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = ( event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
AgenticSystemTurnResponseEventType.turn_start.value AgentTurnResponseEventType.turn_start.value
) )
turn_id: str turn_id: str
@json_schema_type @json_schema_type
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel): class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = ( event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
AgenticSystemTurnResponseEventType.turn_complete.value AgentTurnResponseEventType.turn_complete.value
) )
turn: Turn turn: Turn
@json_schema_type @json_schema_type
class AgenticSystemTurnResponseEvent(BaseModel): class AgentTurnResponseEvent(BaseModel):
"""Streamed agent execution response.""" """Streamed agent execution response."""
payload: Annotated[ payload: Annotated[
Union[ Union[
AgenticSystemTurnResponseStepStartPayload, AgentTurnResponseStepStartPayload,
AgenticSystemTurnResponseStepProgressPayload, AgentTurnResponseStepProgressPayload,
AgenticSystemTurnResponseStepCompletePayload, AgentTurnResponseStepCompletePayload,
AgenticSystemTurnResponseTurnStartPayload, AgentTurnResponseTurnStartPayload,
AgenticSystemTurnResponseTurnCompletePayload, AgentTurnResponseTurnCompletePayload,
], ],
Field(discriminator="event_type"), Field(discriminator="event_type"),
] ]
@json_schema_type @json_schema_type
class AgenticSystemCreateResponse(BaseModel): class AgentCreateResponse(BaseModel):
agent_id: str agent_id: str
@json_schema_type @json_schema_type
class AgenticSystemSessionCreateResponse(BaseModel): class AgentSessionCreateResponse(BaseModel):
session_id: str session_id: str
@json_schema_type @json_schema_type
class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn): class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
agent_id: str agent_id: str
session_id: str session_id: str
@ -397,24 +391,24 @@ class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
@json_schema_type @json_schema_type
class AgenticSystemTurnResponseStreamChunk(BaseModel): class AgentTurnResponseStreamChunk(BaseModel):
event: AgenticSystemTurnResponseEvent event: AgentTurnResponseEvent
@json_schema_type @json_schema_type
class AgenticSystemStepResponse(BaseModel): class AgentStepResponse(BaseModel):
step: Step step: Step
class AgenticSystem(Protocol): class Agents(Protocol):
@webmethod(route="/agentic_system/create") @webmethod(route="/agents/create")
async def create_agentic_system( async def create_agent(
self, self,
agent_config: AgentConfig, agent_config: AgentConfig,
) -> AgenticSystemCreateResponse: ... ) -> AgentCreateResponse: ...
@webmethod(route="/agentic_system/turn/create") @webmethod(route="/agents/turn/create")
async def create_agentic_system_turn( async def create_agent_turn(
self, self,
agent_id: str, agent_id: str,
session_id: str, session_id: str,
@ -426,42 +420,40 @@ class AgenticSystem(Protocol):
], ],
attachments: Optional[List[Attachment]] = None, attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
) -> AgenticSystemTurnResponseStreamChunk: ... ) -> AgentTurnResponseStreamChunk: ...
@webmethod(route="/agentic_system/turn/get") @webmethod(route="/agents/turn/get")
async def get_agentic_system_turn( async def get_agents_turn(
self, self,
agent_id: str, agent_id: str,
turn_id: str, turn_id: str,
) -> Turn: ... ) -> Turn: ...
@webmethod(route="/agentic_system/step/get") @webmethod(route="/agents/step/get")
async def get_agentic_system_step( async def get_agents_step(
self, agent_id: str, turn_id: str, step_id: str self, agent_id: str, turn_id: str, step_id: str
) -> AgenticSystemStepResponse: ... ) -> AgentStepResponse: ...
@webmethod(route="/agentic_system/session/create") @webmethod(route="/agents/session/create")
async def create_agentic_system_session( async def create_agent_session(
self, self,
agent_id: str, agent_id: str,
session_name: str, session_name: str,
) -> AgenticSystemSessionCreateResponse: ... ) -> AgentSessionCreateResponse: ...
@webmethod(route="/agentic_system/session/get") @webmethod(route="/agents/session/get")
async def get_agentic_system_session( async def get_agents_session(
self, self,
agent_id: str, agent_id: str,
session_id: str, session_id: str,
turn_ids: Optional[List[str]] = None, turn_ids: Optional[List[str]] = None,
) -> Session: ... ) -> Session: ...
@webmethod(route="/agentic_system/session/delete") @webmethod(route="/agents/session/delete")
async def delete_agentic_system_session( async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...
self, agent_id: str, session_id: str
) -> None: ...
@webmethod(route="/agentic_system/delete") @webmethod(route="/agents/delete")
async def delete_agentic_system( async def delete_agents(
self, self,
agent_id: str, agent_id: str,
) -> None: ... ) -> None: ...


@ -6,56 +6,58 @@
import asyncio import asyncio
import json import json
import os
from typing import AsyncGenerator from typing import AsyncGenerator
import fire import fire
import httpx import httpx
from dotenv import load_dotenv
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.core.datatypes import RemoteProviderConfig from llama_stack.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403 from .agents import * # noqa: F403
from .event_logger import EventLogger from .event_logger import EventLogger
load_dotenv()
async def get_client_impl(config: RemoteProviderConfig, _deps): async def get_client_impl(config: RemoteProviderConfig, _deps):
return AgenticSystemClient(config.url) return AgentsClient(config.url)
def encodable_dict(d: BaseModel): def encodable_dict(d: BaseModel):
return json.loads(d.json()) return json.loads(d.json())
class AgenticSystemClient(AgenticSystem): class AgentsClient(Agents):
def __init__(self, base_url: str): def __init__(self, base_url: str):
self.base_url = base_url self.base_url = base_url
async def create_agentic_system( async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
self, agent_config: AgentConfig
) -> AgenticSystemCreateResponse:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/agentic_system/create", f"{self.base_url}/agents/create",
json={ json={
"agent_config": encodable_dict(agent_config), "agent_config": encodable_dict(agent_config),
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
return AgenticSystemCreateResponse(**response.json()) return AgentCreateResponse(**response.json())
async def create_agentic_system_session( async def create_agent_session(
self, self,
agent_id: str, agent_id: str,
session_name: str, session_name: str,
) -> AgenticSystemSessionCreateResponse: ) -> AgentSessionCreateResponse:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/agentic_system/session/create", f"{self.base_url}/agents/session/create",
json={ json={
"agent_id": agent_id, "agent_id": agent_id,
"session_name": session_name, "session_name": session_name,
@ -63,16 +65,16 @@ class AgenticSystemClient(AgenticSystem):
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
return AgenticSystemSessionCreateResponse(**response.json()) return AgentSessionCreateResponse(**response.json())
async def create_agentic_system_turn( async def create_agent_turn(
self, self,
request: AgenticSystemTurnCreateRequest, request: AgentTurnCreateRequest,
) -> AsyncGenerator: ) -> AsyncGenerator:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
async with client.stream( async with client.stream(
"POST", "POST",
f"{self.base_url}/agentic_system/turn/create", f"{self.base_url}/agents/turn/create",
json=encodable_dict(request), json=encodable_dict(request),
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
@ -86,7 +88,7 @@ class AgenticSystemClient(AgenticSystem):
cprint(data, "red") cprint(data, "red")
continue continue
yield AgenticSystemTurnResponseStreamChunk(**jdata) yield AgentTurnResponseStreamChunk(**jdata)
except Exception as e: except Exception as e:
print(data) print(data)
print(f"Error with parsing or validation: {e}") print(f"Error with parsing or validation: {e}")
@ -102,16 +104,16 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
tool_prompt_format=ToolPromptFormat.function_tag, tool_prompt_format=ToolPromptFormat.function_tag,
) )
create_response = await api.create_agentic_system(agent_config) create_response = await api.create_agent(agent_config)
session_response = await api.create_agentic_system_session( session_response = await api.create_agent_session(
agent_id=create_response.agent_id, agent_id=create_response.agent_id,
session_name="test_session", session_name="test_session",
) )
for content in user_prompts: for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"]) cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agentic_system_turn( iterator = api.create_agent_turn(
AgenticSystemTurnCreateRequest( AgentTurnCreateRequest(
agent_id=create_response.agent_id, agent_id=create_response.agent_id,
session_id=session_response.session_id, session_id=session_response.session_id,
messages=[ messages=[
@ -128,11 +130,14 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
async def run_main(host: str, port: int): async def run_main(host: str, port: int):
api = AgenticSystemClient(f"http://{host}:{port}") api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [ tool_definitions = [
SearchToolDefinition(engine=SearchEngineType.bing), SearchToolDefinition(
WolframAlphaToolDefinition(), engine=SearchEngineType.brave,
api_key=os.getenv("BRAVE_SEARCH_API_KEY"),
),
WolframAlphaToolDefinition(api_key=os.getenv("WOLFRAM_ALPHA_API_KEY")),
CodeInterpreterToolDefinition(), CodeInterpreterToolDefinition(),
] ]
tool_definitions += [ tool_definitions += [
@ -165,7 +170,7 @@ async def run_main(host: str, port: int):
async def run_rag(host: str, port: int): async def run_rag(host: str, port: int):
api = AgenticSystemClient(f"http://{host}:{port}") api = AgentsClient(f"http://{host}:{port}")
urls = [ urls = [
"memory_optimizations.rst", "memory_optimizations.rst",
@ -186,7 +191,7 @@ async def run_rag(host: str, port: int):
] ]
# Alternatively, you can pre-populate the memory bank with documents for example, # Alternatively, you can pre-populate the memory bank with documents for example,
# using `llama_toolchain.memory.client`. Then you can grab the bank_id # using `llama_stack.memory.client`. Then you can grab the bank_id
# from the output of that run. # from the output of that run.
tool_definitions = [ tool_definitions = [
MemoryToolDefinition( MemoryToolDefinition(


@ -9,12 +9,9 @@ from typing import Optional
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tool_utils import ToolUtils from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_toolchain.agentic_system.api import ( from termcolor import cprint
AgenticSystemTurnResponseEventType,
StepType,
)
class LogEvent: class LogEvent:
@ -40,7 +37,7 @@ class LogEvent:
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush) cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
EventType = AgenticSystemTurnResponseEventType EventType = AgentTurnResponseEventType
class EventLogger: class EventLogger:


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .batch_inference import * # noqa: F401 F403


@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
@json_schema_type @json_schema_type


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .dataset import * # noqa: F401 F403


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .evals import * # noqa: F401 F403


@ -12,8 +12,8 @@ from llama_models.schema_utils import webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.dataset.api import * # noqa: F403 from llama_stack.apis.dataset import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403 from llama_stack.apis.common.training_types import * # noqa: F403
class TextGenerationMetric(Enum): class TextGenerationMetric(Enum):


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inference import * # noqa: F401 F403


@ -11,11 +11,13 @@ from typing import Any, AsyncGenerator
import fire import fire
import httpx import httpx
from llama_toolchain.core.datatypes import RemoteProviderConfig from llama_stack.distribution.datatypes import RemoteProviderConfig
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from .api import ( from .event_logger import EventLogger
from .inference import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseStreamChunk, ChatCompletionResponseStreamChunk,
@ -23,7 +25,6 @@ from .api import (
Inference, Inference,
UserMessage, UserMessage,
) )
from .event_logger import EventLogger
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference: async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_toolchain.inference.api import ( from llama_stack.apis.inference import (
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk, ChatCompletionResponseStreamChunk,
) )


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .memory import * # noqa: F401 F403


@ -6,17 +6,18 @@
import asyncio import asyncio
import json import json
import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import fire import fire
import httpx import httpx
from llama_stack.distribution.datatypes import RemoteProviderConfig
from termcolor import cprint from termcolor import cprint
from llama_toolchain.core.datatypes import RemoteProviderConfig from .memory import * # noqa: F403
from .api import * # noqa: F403
from .common.file_utils import data_url_from_file from .common.file_utils import data_url_from_file


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .models import * # noqa: F401 F403


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .post_training import * # noqa: F401 F403


@ -14,8 +14,8 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.dataset.api import * # noqa: F403 from llama_stack.apis.dataset import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403 from llama_stack.apis.common.training_types import * # noqa: F403
class OptimizerType(Enum): class OptimizerType(Enum):


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .reward_scoring import * # noqa: F401 F403


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .safety import * # noqa: F401 F403


@ -14,11 +14,11 @@ import httpx
from llama_models.llama3.api.datatypes import UserMessage from llama_models.llama3.api.datatypes import UserMessage
from llama_toolchain.core.datatypes import RemoteProviderConfig from llama_stack.distribution.datatypes import RemoteProviderConfig
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from .api import * # noqa: F403 from .safety import * # noqa: F403
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety: async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:


@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
@json_schema_type @json_schema_type

llama_stack/apis/stack.py (new file, 34 lines)

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.apis.evals import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.reward_scoring import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
class LlamaStack(
Inference,
BatchInference,
Agents,
RewardScoring,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
Memory,
Evaluations,
):
pass


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .synthetic_data_generation import * # noqa: F401 F403


@ -13,7 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.reward_scoring.api import * # noqa: F403 from llama_stack.apis.reward_scoring import * # noqa: F403
class FilteringFunction(Enum): class FilteringFunction(Enum):


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .telemetry import * # noqa: F401 F403


@ -20,7 +20,7 @@ from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
class Download(Subcommand): class Download(Subcommand):
@ -92,7 +92,7 @@ def _hf_download(
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from llama_toolchain.common.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
repo_id = model.huggingface_repo repo_id = model.huggingface_repo
if repo_id is None: if repo_id is None:
@ -106,7 +106,7 @@ def _hf_download(
local_dir=output_dir, local_dir=output_dir,
ignore_patterns=ignore_patterns, ignore_patterns=ignore_patterns,
token=hf_token, token=hf_token,
library_name="llama-toolchain", library_name="llama-stack",
) )
except GatedRepoError: except GatedRepoError:
parser.error( parser.error(
@ -126,7 +126,7 @@ def _hf_download(
def _meta_download(model: "Model", meta_url: str): def _meta_download(model: "Model", meta_url: str):
from llama_models.sku_list import llama_meta_net_info from llama_models.sku_list import llama_meta_net_info
from llama_toolchain.common.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
output_dir = Path(model_local_dir(model.descriptor())) output_dir = Path(model_local_dir(model.descriptor()))
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
@ -188,7 +188,7 @@ class Manifest(BaseModel):
def _download_from_manifest(manifest_file: str): def _download_from_manifest(manifest_file: str):
from llama_toolchain.common.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
with open(manifest_file, "r") as f: with open(manifest_file, "r") as f:
d = json.load(f) d = json.load(f)


@ -31,16 +31,6 @@ class LlamaCLIParser:
ModelParser.create(subparsers) ModelParser.create(subparsers)
StackParser.create(subparsers) StackParser.create(subparsers)
# Import sub-commands from agentic_system if they exist
try:
from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
for module in SUBCOMMAND_MODULES:
module.create(subparsers)
except ImportError:
pass
def parse_args(self) -> argparse.Namespace: def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args() return self.parser.parse_args()


@ -9,12 +9,12 @@ import json
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table
from llama_toolchain.common.serialize import EnumEncoder
from termcolor import colored from termcolor import colored
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.distribution.utils.serialize import EnumEncoder
class ModelDescribe(Subcommand): class ModelDescribe(Subcommand):
"""Show details about a model""" """Show details about a model"""


@ -6,7 +6,7 @@
import argparse import argparse
from llama_toolchain.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
class ModelDownload(Subcommand): class ModelDownload(Subcommand):
@ -19,6 +19,6 @@ class ModelDownload(Subcommand):
formatter_class=argparse.RawTextHelpFormatter, formatter_class=argparse.RawTextHelpFormatter,
) )
from llama_toolchain.cli.download import setup_download_parser from llama_stack.cli.download import setup_download_parser
setup_download_parser(self.parser) setup_download_parser(self.parser)


@ -8,8 +8,8 @@ import argparse
from llama_models.sku_list import all_registered_models from llama_models.sku_list import all_registered_models
from llama_toolchain.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table from llama_stack.cli.table import print_table
class ModelList(Subcommand): class ModelList(Subcommand):


@ -6,12 +6,12 @@
import argparse import argparse
from llama_toolchain.cli.model.describe import ModelDescribe from llama_stack.cli.model.describe import ModelDescribe
from llama_toolchain.cli.model.download import ModelDownload from llama_stack.cli.model.download import ModelDownload
from llama_toolchain.cli.model.list import ModelList from llama_stack.cli.model.list import ModelList
from llama_toolchain.cli.model.template import ModelTemplate from llama_stack.cli.model.template import ModelTemplate
from llama_toolchain.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
class ModelParser(Subcommand): class ModelParser(Subcommand):


@ -9,7 +9,7 @@ import textwrap
from termcolor import colored from termcolor import colored
from llama_toolchain.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
class ModelTemplate(Subcommand): class ModelTemplate(Subcommand):
@ -75,7 +75,7 @@ class ModelTemplate(Subcommand):
render_jinja_template, render_jinja_template,
) )
from llama_toolchain.cli.table import print_table from llama_stack.cli.table import print_table
if args.name: if args.name:
tool_prompt_format = self._prompt_type(args.format) tool_prompt_format = self._prompt_type(args.format)


@ -6,8 +6,8 @@
import argparse import argparse
from llama_toolchain.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path from pathlib import Path
import yaml import yaml
@ -29,7 +29,7 @@ class StackBuild(Subcommand):
self.parser.add_argument( self.parser.add_argument(
"config", "config",
type=str, type=str,
help="Path to a config file to use for the build. You may find example configs in llama_toolchain/configs/distributions", help="Path to a config file to use for the build. You may find example configs in llama_stack/distribution/example_configs",
) )
self.parser.add_argument( self.parser.add_argument(
@ -44,17 +44,17 @@ class StackBuild(Subcommand):
import json import json
import os import os
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_toolchain.common.serialize import EnumEncoder from llama_stack.distribution.utils.serialize import EnumEncoder
from llama_toolchain.core.package import ApiInput, build_package, ImageType from llama_stack.distribution.build import ApiInput, build_image, ImageType
from termcolor import cprint from termcolor import cprint
# save build.yaml spec for building same distribution again # save build.yaml spec for building same distribution again
if build_config.image_type == ImageType.docker.value: if build_config.image_type == ImageType.docker.value:
# docker needs build file to be in the llama-stack repo dir to be able to copy over to the image # docker needs build file to be in the llama-stack repo dir to be able to copy over to the image
llama_toolchain_path = Path(os.path.relpath(__file__)).parent.parent.parent llama_stack_path = Path(os.path.relpath(__file__)).parent.parent.parent
build_dir = ( build_dir = (
llama_toolchain_path / "configs/distributions" / build_config.image_type llama_stack_path / "configs/distributions" / build_config.image_type
) )
else: else:
build_dir = DISTRIBS_BASE_DIR / build_config.image_type build_dir = DISTRIBS_BASE_DIR / build_config.image_type
@ -66,7 +66,7 @@ class StackBuild(Subcommand):
to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder)) to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
f.write(yaml.dump(to_write, sort_keys=False)) f.write(yaml.dump(to_write, sort_keys=False))
build_package(build_config, build_file_path) build_image(build_config, build_file_path)
cprint( cprint(
f"Build spec configuration saved at {str(build_file_path)}", f"Build spec configuration saved at {str(build_file_path)}",
@ -74,12 +74,12 @@ class StackBuild(Subcommand):
) )
def _run_stack_build_command(self, args: argparse.Namespace) -> None: def _run_stack_build_command(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.prompt_for_config import prompt_for_config from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_toolchain.core.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
if not args.config: if not args.config:
self.parser.error( self.parser.error(
"No config file specified. Please use `llama stack build /path/to/*-build.yaml`. Example config files can be found in llama_toolchain/configs/distributions" "No config file specified. Please use `llama stack build /path/to/*-build.yaml`. Example config files can be found in llama_stack/distribution/example_configs"
) )
return return

View file

@ -13,11 +13,11 @@ import pkg_resources
import yaml import yaml
from termcolor import cprint from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.common.exec import run_with_pty from llama_stack.distribution.utils.exec import run_with_pty
from llama_toolchain.core.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import os import os
@ -49,7 +49,7 @@ class StackConfigure(Subcommand):
) )
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None: def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.core.package import ImageType from llama_stack.distribution.build import ImageType
docker_image = None docker_image = None
build_config_file = Path(args.config) build_config_file = Path(args.config)
@ -66,7 +66,7 @@ class StackConfigure(Subcommand):
os.makedirs(builds_dir, exist_ok=True) os.makedirs(builds_dir, exist_ok=True)
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(
"llama_toolchain", "core/configure_container.sh" "llama_stack", "distribution/configure_container.sh"
) )
script_args = [script, docker_image, str(builds_dir)] script_args = [script, docker_image, str(builds_dir)]
@ -95,8 +95,8 @@ class StackConfigure(Subcommand):
build_config: BuildConfig, build_config: BuildConfig,
output_dir: Optional[str] = None, output_dir: Optional[str] = None,
): ):
from llama_toolchain.common.serialize import EnumEncoder from llama_stack.distribution.configure import configure_api_providers
from llama_toolchain.core.configure import configure_api_providers from llama_stack.distribution.utils.serialize import EnumEncoder
builds_dir = BUILDS_BASE_DIR / build_config.image_type builds_dir = BUILDS_BASE_DIR / build_config.image_type
if output_dir: if output_dir:
@ -105,16 +105,9 @@ class StackConfigure(Subcommand):
image_name = build_config.name.replace("::", "-") image_name = build_config.name.replace("::", "-")
run_config_file = builds_dir / f"{image_name}-run.yaml" run_config_file = builds_dir / f"{image_name}-run.yaml"
api2providers = build_config.distribution_spec.providers
stub_config = {
api_str: {"provider_id": provider}
for api_str, provider in api2providers.items()
}
if run_config_file.exists(): if run_config_file.exists():
cprint( cprint(
f"Configuration already exists for {build_config.name}. Will overwrite...", f"Configuration already exists at `{str(run_config_file)}`. Will overwrite...",
"yellow", "yellow",
attrs=["bold"], attrs=["bold"],
) )
@ -123,10 +116,12 @@ class StackConfigure(Subcommand):
config = StackRunConfig( config = StackRunConfig(
built_at=datetime.now(), built_at=datetime.now(),
image_name=image_name, image_name=image_name,
providers=stub_config, apis_to_serve=[],
provider_map={},
) )
config.providers = configure_api_providers(config.providers) config = configure_api_providers(config, build_config.distribution_spec)
config.docker_image = ( config.docker_image = (
image_name if build_config.image_type == "docker" else None image_name if build_config.image_type == "docker" else None
) )

View file

@ -6,7 +6,7 @@
import argparse import argparse
from llama_toolchain.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
class StackListApis(Subcommand): class StackListApis(Subcommand):
@ -25,8 +25,8 @@ class StackListApis(Subcommand):
pass pass
def _run_apis_list_cmd(self, args: argparse.Namespace) -> None: def _run_apis_list_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.cli.table import print_table from llama_stack.cli.table import print_table
from llama_toolchain.core.distribution import stack_apis from llama_stack.distribution.distribution import stack_apis
# eventually, this should query a registry at llama.meta.com/llamastack/distributions # eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [ headers = [

View file

@ -6,7 +6,7 @@
import argparse import argparse
from llama_toolchain.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
class StackListProviders(Subcommand): class StackListProviders(Subcommand):
@ -22,7 +22,7 @@ class StackListProviders(Subcommand):
self.parser.set_defaults(func=self._run_providers_list_cmd) self.parser.set_defaults(func=self._run_providers_list_cmd)
def _add_arguments(self): def _add_arguments(self):
from llama_toolchain.core.distribution import stack_apis from llama_stack.distribution.distribution import stack_apis
api_values = [a.value for a in stack_apis()] api_values = [a.value for a in stack_apis()]
self.parser.add_argument( self.parser.add_argument(
@ -33,8 +33,8 @@ class StackListProviders(Subcommand):
) )
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None: def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.cli.table import print_table from llama_stack.cli.table import print_table
from llama_toolchain.core.distribution import Api, api_providers from llama_stack.distribution.distribution import Api, api_providers
all_providers = api_providers() all_providers = api_providers()
providers_for_api = all_providers[Api(args.api)] providers_for_api = all_providers[Api(args.api)]

View file

@ -11,8 +11,8 @@ from pathlib import Path
import pkg_resources import pkg_resources
import yaml import yaml
from llama_toolchain.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
class StackRun(Subcommand): class StackRun(Subcommand):
@ -47,7 +47,7 @@ class StackRun(Subcommand):
) )
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty from llama_stack.distribution.utils.exec import run_with_pty
if not args.config: if not args.config:
self.parser.error("Must specify a config file to run") self.parser.error("Must specify a config file to run")
@ -67,14 +67,14 @@ class StackRun(Subcommand):
if config.docker_image: if config.docker_image:
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(
"llama_toolchain", "llama_stack",
"core/start_container.sh", "distribution/start_container.sh",
) )
run_args = [script, config.docker_image] run_args = [script, config.docker_image]
else: else:
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(
"llama_toolchain", "llama_stack",
"core/start_conda_env.sh", "distribution/start_conda_env.sh",
) )
run_args = [ run_args = [
script, script,

View file

@ -6,7 +6,7 @@
import argparse import argparse
from llama_toolchain.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from .build import StackBuild from .build import StackBuild
from .configure import StackConfigure from .configure import StackConfigure

View file

@ -4,26 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
import os
from datetime import datetime
from enum import Enum from enum import Enum
from typing import List, Optional from typing import List, Optional
import pkg_resources import pkg_resources
import yaml
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.common.serialize import EnumEncoder
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from llama_toolchain.core.datatypes import * # noqa: F403 from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path from pathlib import Path
from llama_toolchain.core.distribution import api_providers, SERVER_DEPENDENCIES from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
class ImageType(Enum): class ImageType(Enum):
@ -41,7 +35,7 @@ class ApiInput(BaseModel):
provider: str provider: str
def build_package(build_config: BuildConfig, build_file_path: Path): def build_image(build_config: BuildConfig, build_file_path: Path):
package_deps = Dependencies( package_deps = Dependencies(
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim", docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
pip_packages=SERVER_DEPENDENCIES, pip_packages=SERVER_DEPENDENCIES,
@ -49,21 +43,32 @@ def build_package(build_config: BuildConfig, build_file_path: Path):
# extend package dependencies based on providers spec # extend package dependencies based on providers spec
all_providers = api_providers() all_providers = api_providers()
for api_str, provider in build_config.distribution_spec.providers.items(): for (
api_str,
provider_or_providers,
) in build_config.distribution_spec.providers.items():
providers_for_api = all_providers[Api(api_str)] providers_for_api = all_providers[Api(api_str)]
if provider not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
)
provider_spec = providers_for_api[provider] providers = (
package_deps.pip_packages.extend(provider_spec.pip_packages) provider_or_providers
if provider_spec.docker_image: if isinstance(provider_or_providers, list)
raise ValueError("A stack's dependencies cannot have a docker image") else [provider_or_providers]
)
for provider in providers:
if provider not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
)
provider_spec = providers_for_api[provider]
package_deps.pip_packages.extend(provider_spec.pip_packages)
if provider_spec.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")
if build_config.image_type == ImageType.docker.value: if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(
"llama_toolchain", "core/build_container.sh" "llama_stack", "distribution/build_container.sh"
) )
args = [ args = [
script, script,
@ -74,7 +79,7 @@ def build_package(build_config: BuildConfig, build_file_path: Path):
] ]
else: else:
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(
"llama_toolchain", "core/build_conda_env.sh" "llama_stack", "distribution/build_conda_env.sh"
) )
args = [ args = [
script, script,

View file

@ -7,11 +7,11 @@
# the root directory of this source tree. # the root directory of this source tree.
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-} LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-} LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR" echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi fi
if [ -n "$LLAMA_MODELS_DIR" ]; then if [ -n "$LLAMA_MODELS_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR" echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
@ -78,19 +78,19 @@ ensure_conda_env_python310() {
if [ -n "$TEST_PYPI_VERSION" ]; then if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first # these packages are damaged in test-pypi, so install them first
pip install fastapi libcst pip install fastapi libcst
pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-toolchain==$TEST_PYPI_VERSION $pip_dependencies pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION $pip_dependencies
else else
# Re-installing llama-toolchain in the new conda environment # Re-installing llama-stack in the new conda environment
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}\n" >&2 printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
exit 1 exit 1
fi fi
printf "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR\n" printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
pip install --no-cache-dir -e "$LLAMA_TOOLCHAIN_DIR" pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else else
pip install --no-cache-dir llama-toolchain pip install --no-cache-dir llama-stack
fi fi
if [ -n "$LLAMA_MODELS_DIR" ]; then if [ -n "$LLAMA_MODELS_DIR" ]; then

View file

@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-} LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-} LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ "$#" -ne 4 ]; then if [ "$#" -ne 4 ]; then
@ -55,17 +55,17 @@ RUN apt-get update && apt-get install -y \
EOF EOF
toolchain_mount="/app/llama-toolchain-source" stack_mount="/app/llama-stack-source"
models_mount="/app/llama-models-source" models_mount="/app/llama-models-source"
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then if [ ! -d "$LLAMA_STACK_DIR" ]; then
echo "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2 echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2
exit 1 exit 1
fi fi
add_to_docker "RUN pip install $toolchain_mount" add_to_docker "RUN pip install $stack_mount"
else else
add_to_docker "RUN pip install llama-toolchain" add_to_docker "RUN pip install llama-stack"
fi fi
if [ -n "$LLAMA_MODELS_DIR" ]; then if [ -n "$LLAMA_MODELS_DIR" ]; then
@ -90,7 +90,7 @@ add_to_docker <<EOF
# This would be good in production but for debugging flexibility lets not add it right now # This would be good in production but for debugging flexibility lets not add it right now
# We need a more solid production ready entrypoint.sh anyway # We need a more solid production ready entrypoint.sh anyway
# #
# ENTRYPOINT ["python", "-m", "llama_toolchain.core.server"] # ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
EOF EOF
@ -101,8 +101,8 @@ cat $TEMP_DIR/Dockerfile
printf "\n" printf "\n"
mounts="" mounts=""
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_TOOLCHAIN_DIR):$toolchain_mount" mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
fi fi
if [ -n "$LLAMA_MODELS_DIR" ]; then if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount" mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"

View file

@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
from llama_stack.distribution.datatypes import * # noqa: F403
from termcolor import cprint
from llama_stack.distribution.distribution import api_providers, stack_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
# These are hacks so we can re-use the `prompt_for_config` utility
# This needs a bunch of work to be made very user friendly.
class ReqApis(BaseModel):
apis_to_serve: List[str]
def make_routing_entry_type(config_class: Any):
class BaseModelWithConfig(BaseModel):
routing_key: str
config: config_class
return BaseModelWithConfig
# TODO: make sure we can deal with existing configuration values correctly
# instead of just overwriting them
def configure_api_providers(
config: StackRunConfig, spec: DistributionSpec
) -> StackRunConfig:
cprint("Configuring APIs to serve...", "white", attrs=["bold"])
print("Enter comma-separated list of APIs to serve:")
apis = config.apis_to_serve or list(spec.providers.keys())
apis = [a for a in apis if a != "telemetry"]
req_apis = ReqApis(
apis_to_serve=apis,
)
req_apis = prompt_for_config(ReqApis, req_apis)
config.apis_to_serve = req_apis.apis_to_serve
print("")
apis = [v.value for v in stack_apis()]
all_providers = api_providers()
for api_str in spec.providers.keys():
if api_str not in apis:
raise ValueError(f"Unknown API `{api_str}`")
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
api = Api(api_str)
provider_or_providers = spec.providers[api_str]
if isinstance(provider_or_providers, list) and len(provider_or_providers) > 1:
print(
"You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n"
)
routing_entries = []
for p in provider_or_providers:
print(f"Configuring provider `{p}`...")
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class)
# TODO: we need to validate the routing keys, and
# perhaps it is better if we break this out into asking
# for a routing key separately from the associated config
wrapper_type = make_routing_entry_type(config_type)
rt_entry = prompt_for_config(wrapper_type, None)
routing_entries.append(
ProviderRoutingEntry(
provider_id=p,
routing_key=rt_entry.routing_key,
config=rt_entry.config.dict(),
)
)
config.provider_map[api_str] = routing_entries
else:
p = (
provider_or_providers[0]
if isinstance(provider_or_providers, list)
else provider_or_providers
)
print(f"Configuring provider `{p}`...")
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class)
try:
provider_config = config.provider_map.get(api_str)
if provider_config:
existing = config_type(**provider_config.config)
else:
existing = None
except Exception:
existing = None
cfg = prompt_for_config(config_type, existing)
config.provider_map[api_str] = GenericProviderConfig(
provider_id=p,
config=cfg.dict(),
)
return config

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .control_plane import * # noqa: F401 F403

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import RedisImplConfig
async def get_adapter_impl(config: RedisImplConfig, _deps):
from .redis import RedisControlPlaneAdapter
impl = RedisControlPlaneAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class RedisImplConfig(BaseModel):
url: str = Field(
description="The URL for the Redis server",
)
namespace: Optional[str] = Field(
default=None,
description="All keys will be prefixed with this namespace",
)

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime, timedelta
from typing import Any, List, Optional
from redis.asyncio import Redis
from llama_stack.apis.control_plane import * # noqa: F403
from .config import RedisImplConfig
class RedisControlPlaneAdapter(ControlPlane):
def __init__(self, config: RedisImplConfig):
self.config = config
async def initialize(self) -> None:
self.redis = Redis.from_url(self.config.url)
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
return f"{self.config.namespace}:{key}"
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
) -> None:
key = self._namespaced_key(key)
await self.redis.set(key, value)
if expiration:
await self.redis.expireat(key, expiration)
async def get(self, key: str) -> Optional[ControlPlaneValue]:
key = self._namespaced_key(key)
value = await self.redis.get(key)
if value is None:
return None
ttl = await self.redis.ttl(key)
expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None
return ControlPlaneValue(key=key, value=value, expiration=expiration)
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
keys = await self.redis.keys(f"{start_key}*")
result = []
for key in keys:
if key <= end_key:
value = await self.get(key)
if value:
result.append(value)
return result

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SqliteControlPlaneConfig
async def get_provider_impl(config: SqliteControlPlaneConfig, _deps):
from .control_plane import SqliteControlPlane
impl = SqliteControlPlane(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class SqliteControlPlaneConfig(BaseModel):
db_path: str = Field(
description="File path for the sqlite database",
)
table_name: str = Field(
default="llamastack_control_plane",
description="Table into which all the keys will be placed",
)

View file

@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from typing import Any, List, Optional
import aiosqlite
from llama_stack.apis.control_plane import * # noqa: F403
from .config import SqliteControlPlaneConfig
class SqliteControlPlane(ControlPlane):
def __init__(self, config: SqliteControlPlaneConfig):
self.db_path = config.db_path
self.table_name = config.table_name
async def initialize(self):
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
)
"""
)
await db.commit()
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
(key, json.dumps(value), expiration),
)
await db.commit()
async def get(self, key: str) -> Optional[ControlPlaneValue]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
) as cursor:
row = await cursor.fetchone()
if row is None:
return None
value, expiration = row
return ControlPlaneValue(
key=key, value=json.loads(value), expiration=expiration
)
async def delete(self, key: str) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit()
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
) as cursor:
result = []
async for row in cursor:
key, value, expiration = row
result.append(
ControlPlaneValue(
key=key, value=json.loads(value), expiration=expiration
)
)
return result

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@json_schema_type
class ControlPlaneValue(BaseModel):
key: str
value: Any
expiration: Optional[datetime] = None
@json_schema_type
class ControlPlane(Protocol):
@webmethod(route="/control_plane/set")
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
) -> None: ...
@webmethod(route="/control_plane/get", method="GET")
async def get(self, key: str) -> Optional[ControlPlaneValue]: ...
@webmethod(route="/control_plane/delete")
async def delete(self, key: str) -> None: ...
@webmethod(route="/control_plane/range", method="GET")
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: ...

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.control_plane,
provider_id="sqlite",
pip_packages=["aiosqlite"],
module="llama_stack.providers.impls.sqlite.control_plane",
config_class="llama_stack.providers.impls.sqlite.control_plane.SqliteControlPlaneConfig",
),
remote_provider_spec(
Api.control_plane,
AdapterSpec(
adapter_id="redis",
pip_packages=["redis"],
module="llama_stack.providers.adapters.control_plane.redis",
),
),
]

View file

@ -6,7 +6,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Union
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
@ -17,7 +17,7 @@ from pydantic import BaseModel, Field, validator
class Api(Enum): class Api(Enum):
inference = "inference" inference = "inference"
safety = "safety" safety = "safety"
agentic_system = "agentic_system" agents = "agents"
memory = "memory" memory = "memory"
telemetry = "telemetry" telemetry = "telemetry"
@ -43,6 +43,33 @@ class ProviderSpec(BaseModel):
) )
@json_schema_type
class RouterProviderSpec(ProviderSpec):
provider_id: str = "router"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on RouterProviderSpec")
class GenericProviderConfig(BaseModel):
provider_id: str
config: Dict[str, Any]
@json_schema_type @json_schema_type
class AdapterSpec(BaseModel): class AdapterSpec(BaseModel):
adapter_id: str = Field( adapter_id: str = Field(
@ -124,7 +151,7 @@ as being "Llama Stack compatible"
def module(self) -> str: def module(self) -> str:
if self.adapter: if self.adapter:
return self.adapter.module return self.adapter.module
return f"llama_toolchain.{self.api.value}.client" return f"llama_stack.apis.{self.api.value}.client"
@property @property
def pip_packages(self) -> List[str]: def pip_packages(self) -> List[str]:
@ -140,7 +167,7 @@ def remote_provider_spec(
config_class = ( config_class = (
adapter.config_class adapter.config_class
if adapter and adapter.config_class if adapter and adapter.config_class
else "llama_toolchain.core.datatypes.RemoteProviderConfig" else "llama_stack.distribution.datatypes.RemoteProviderConfig"
) )
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote" provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
@ -156,12 +183,23 @@ class DistributionSpec(BaseModel):
description="Description of the distribution", description="Description of the distribution",
) )
docker_image: Optional[str] = None docker_image: Optional[str] = None
providers: Dict[str, str] = Field( providers: Dict[str, Union[str, List[str]]] = Field(
default_factory=dict, default_factory=dict,
description="Provider Types for each of the APIs provided by this distribution", description="""
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.""",
) )
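As a hedged illustration of the multi-provider case described above (the provider ids are examples only, and a routing table must then be supplied in the run configuration):

from llama_stack.distribution.datatypes import DistributionSpec

spec = DistributionSpec(
    description="Local inference with two memory providers",
    providers={
        "inference": "meta-reference",
        # Two providers for one API: requires a routing table at run time.
        "memory": ["meta-reference", "remote::chromadb"],
        "safety": "meta-reference",
        "agents": "meta-reference",
        "telemetry": "meta-reference",
    },
)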
@json_schema_type
class ProviderRoutingEntry(GenericProviderConfig):
routing_key: str
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
@json_schema_type @json_schema_type
class StackRunConfig(BaseModel): class StackRunConfig(BaseModel):
built_at: datetime built_at: datetime
@ -181,12 +219,22 @@ this could be just a hash
default=None, default=None,
description="Reference to the conda environment if this package refers to a conda environment", description="Reference to the conda environment if this package refers to a conda environment",
) )
providers: Dict[str, Any] = Field( apis_to_serve: List[str] = Field(
default_factory=dict,
description=""" description="""
Provider configurations for each of the APIs provided by this package. This includes configurations for The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
the dependencies of these providers as well. )
""", provider_map: Dict[str, ProviderMapEntry] = Field(
description="""
Provider configurations for each of the APIs provided by this package.
Given an API, you can specify a single provider or a "routing table". Each entry in the routing
table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
As examples:
- the "inference" API interprets the routing_key as a "model"
- the "memory" API interprets the routing_key as the type of a "memory bank"
The key may also support wild-cards; the routing_key is used to route to the correct provider.""",
) )
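Concretely, each provider_map entry is either a single GenericProviderConfig or, for a routed API, a list of ProviderRoutingEntry objects. A sketch with made-up provider ids, routing keys, and config values:

from llama_stack.distribution.datatypes import (
    GenericProviderConfig,
    ProviderRoutingEntry,
)

provider_map = {
    # Single provider for an API.
    "safety": GenericProviderConfig(
        provider_id="meta-reference",
        config={},
    ),
    # Routing table: for "inference" the routing_key is a model name.
    "inference": [
        ProviderRoutingEntry(
            provider_id="meta-reference",
            routing_key="Meta-Llama3.1-8B-Instruct",
            config={"max_seq_len": 4096},
        ),
        ProviderRoutingEntry(
            provider_id="remote::ollama",
            routing_key="Meta-Llama3.1-70B-Instruct",
            config={"url": "http://localhost:11434"},
        ),
    ],
}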

View file

@ -8,18 +8,19 @@ import importlib
import inspect import inspect
from typing import Dict, List from typing import Dict, List
from llama_toolchain.agentic_system.api import AgenticSystem from llama_stack.apis.agents import Agents
from llama_toolchain.inference.api import Inference from llama_stack.apis.inference import Inference
from llama_toolchain.memory.api import Memory from llama_stack.apis.memory import Memory
from llama_toolchain.safety.api import Safety from llama_stack.apis.safety import Safety
from llama_toolchain.telemetry.api import Telemetry from llama_stack.apis.telemetry import Telemetry
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
# These are the dependencies needed by the distribution server. # These are the dependencies needed by the distribution server.
# `llama-toolchain` is automatically installed by the installation script. # `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [ SERVER_DEPENDENCIES = [
"fastapi", "fastapi",
"fire",
"uvicorn", "uvicorn",
] ]
@ -34,7 +35,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
protocols = { protocols = {
Api.inference: Inference, Api.inference: Inference,
Api.safety: Safety, Api.safety: Safety,
Api.agentic_system: AgenticSystem, Api.agents: Agents,
Api.memory: Memory, Api.memory: Memory,
Api.telemetry: Telemetry, Api.telemetry: Telemetry,
} }
@ -67,7 +68,7 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {} ret = {}
for api in stack_apis(): for api in stack_apis():
name = api.name.lower() name = api.name.lower()
module = importlib.import_module(f"llama_toolchain.{name}.providers") module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = { ret[api] = {
"remote": remote_provider_spec(api), "remote": remote_provider_spec(api),
**{a.provider_id: a for a in module.available_providers()}, **{a.provider_id: a for a in module.available_providers()},

View file

@ -0,0 +1,10 @@
name: local-conda-example
distribution_spec:
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -3,8 +3,8 @@ distribution_spec:
description: Use Fireworks.ai for running LLM inference description: Use Fireworks.ai for running LLM inference
providers: providers:
inference: remote::fireworks inference: remote::fireworks
memory: meta-reference-faiss memory: meta-reference
safety: meta-reference safety: meta-reference
agentic_system: meta-reference agents: meta-reference
telemetry: console telemetry: meta-reference
image_type: conda image_type: conda

View file

@ -3,8 +3,8 @@ distribution_spec:
description: Like local, but use ollama for running LLM inference description: Like local, but use ollama for running LLM inference
providers: providers:
inference: remote::ollama inference: remote::ollama
memory: meta-reference-faiss memory: meta-reference
safety: meta-reference safety: meta-reference
agentic_system: meta-reference agents: meta-reference
telemetry: console telemetry: meta-reference
image_type: conda image_type: conda

View file

@ -3,8 +3,8 @@ distribution_spec:
description: Use TGI (local or with Hugging Face Inference Endpoints for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint). description: Use TGI (local or with Hugging Face Inference Endpoints for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint).
providers: providers:
inference: remote::tgi inference: remote::tgi
memory: meta-reference-faiss memory: meta-reference
safety: meta-reference safety: meta-reference
agentic_system: meta-reference agents: meta-reference
telemetry: console telemetry: meta-reference
image_type: conda image_type: conda

View file

@ -3,8 +3,8 @@ distribution_spec:
description: Use Together.ai for running LLM inference description: Use Together.ai for running LLM inference
providers: providers:
inference: remote::together inference: remote::together
memory: meta-reference-faiss memory: meta-reference
safety: meta-reference safety: meta-reference
agentic_system: meta-reference agents: meta-reference
telemetry: console telemetry: meta-reference
image_type: conda image_type: conda

View file

@ -0,0 +1,10 @@
name: local-docker-example
distribution_spec:
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: docker

View file

@ -9,6 +9,7 @@ import inspect
import json import json
import signal import signal
import traceback import traceback
from collections.abc import ( from collections.abc import (
AsyncGenerator as AsyncGeneratorABC, AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC, AsyncIterator as AsyncIteratorABC,
@ -38,16 +39,16 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint from termcolor import cprint
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_toolchain.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
end_trace, end_trace,
setup_logger, setup_logger,
SpanStatus, SpanStatus,
start_trace, start_trace,
) )
from llama_stack.distribution.datatypes import * # noqa: F403
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec from llama_stack.distribution.distribution import api_endpoints, api_providers
from .distribution import api_endpoints, api_providers from llama_stack.distribution.utils.dynamic import instantiate_provider
from .dynamic import instantiate_provider
def is_async_iterator_type(typ): def is_async_iterator_type(typ):
@ -271,61 +272,80 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack] return [by_id[x] for x in stack]
def resolve_impls( def snake_to_camel(snake_str):
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any] return "".join(word.capitalize() for word in snake_str.split("_"))
) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(provider_specs.values())
impls = {}
for provider_spec in provider_specs: async def resolve_impls(
api = provider_spec.api provider_map: Dict[str, ProviderMapEntry],
if api.value not in provider_configs: ) -> Dict[Api, Any]:
raise ValueError( """
f"Could not find provider_spec config for {api}. Please add it to the config" Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = api_providers()
specs = {}
for api_str, item in provider_map.items():
api = Api(api_str)
providers = all_providers[api]
if isinstance(item, GenericProviderConfig):
if item.provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[item.provider_id]
else:
assert isinstance(item, list)
inner_specs = []
for rt_entry in item:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.providers.routers.{api.value.lower()}",
api_dependencies=[],
inner_specs=inner_specs,
) )
if isinstance(provider_spec, InlineProviderSpec): sorted_specs = topological_sort(specs.values())
deps = {api: impls[api] for api in provider_spec.api_dependencies}
else: impls = {}
deps = {} for spec in sorted_specs:
provider_config = provider_configs[api.value] api = spec.api
impl = instantiate_provider(provider_spec, provider_config, deps)
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, provider_map[api.value])
impls[api] = impl impls[api] = impl
return impls return impls, specs
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp: with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp) config = StackRunConfig(**yaml.safe_load(fp))
app = FastAPI() app = FastAPI()
all_endpoints = api_endpoints() impls, specs = asyncio.run(resolve_impls(config.provider_map))
all_providers = api_providers()
provider_specs = {}
for api_str, provider_config in config["providers"].items():
api = Api(api_str)
providers = all_providers[api]
provider_id = provider_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_specs[api] = providers[provider_id]
impls = resolve_impls(provider_specs, config)
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
for provider_spec in provider_specs.values(): all_endpoints = api_endpoints()
api = provider_spec.api
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api] endpoints = all_endpoints[api]
impl = impls[api] impl = impls[api]
provider_spec = specs[api]
if ( if (
isinstance(provider_spec, RemoteProviderSpec) isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None and provider_spec.adapter is None

View file

@ -37,6 +37,6 @@ eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name" conda deactivate && conda activate "$env_name"
$CONDA_PREFIX/bin/python \ $CONDA_PREFIX/bin/python \
-m llama_toolchain.core.server \ -m llama_stack.distribution.server.server \
--yaml_config "$yaml_config" \ --yaml_config "$yaml_config" \
--port "$port" "$@" --port "$port" "$@"

View file

@ -38,6 +38,6 @@ podman run -it \
-p $port:$port \ -p $port:$port \
-v "$yaml_config:/app/config.yaml" \ -v "$yaml_config:/app/config.yaml" \
$docker_image \ $docker_image \
python -m llama_toolchain.core.server \ python -m llama_stack.distribution.server.server \
--yaml_config /app/config.yaml \ --yaml_config /app/config.yaml \
--port $port "$@" --port $port "$@"

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
from typing import Any, Dict
from llama_stack.distribution.datatypes import * # noqa: F403
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
deps: Dict[str, Any],
provider_config: ProviderMapEntry,
):
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, RouterProviderSpec):
method = "get_router_impl"
assert isinstance(provider_config, list)
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in provider_config:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_id],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [inner_impls, deps]
else:
method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl

View file

@ -27,6 +27,12 @@ def is_list_of_primitives(field_type):
return False return False
def is_basemodel_without_fields(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
)
def can_recurse(typ): def can_recurse(typ):
return ( return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0 inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
@ -151,6 +157,11 @@ def prompt_for_config(
if get_origin(field_type) is Literal: if get_origin(field_type) is Literal:
continue continue
# BaseModel fields with no sub-fields need no prompting; instantiate with defaults
if is_basemodel_without_fields(field_type):
config_data[field_name] = field_type()
continue
if inspect.isclass(field_type) and issubclass(field_type, Enum): if inspect.isclass(field_type) and issubclass(field_type, Enum):
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):" prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
while True: while True:
@ -254,6 +265,20 @@ def prompt_for_config(
print(f"{str(e)}") print(f"{str(e)}")
continue continue
elif get_origin(field_type) is dict:
try:
value = json.loads(user_input)
if not isinstance(value, dict):
raise ValueError(
"Input must be a JSON-encoded dictionary"
)
except json.JSONDecodeError:
print(
"Invalid JSON. Please enter a valid JSON-encoded dict."
)
continue
# Convert the input to the correct type # Convert the input to the correct type
elif inspect.isclass(field_type) and issubclass( elif inspect.isclass(field_type) and issubclass(
field_type, BaseModel field_type, BaseModel

Some files were not shown because too many files have changed in this diff.