Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
API Updates (#73)
* API Keys passed from Client instead of distro configuration
* delete distribution registry
* Rename the "package" word away
* Introduce a "Router" layer for providers

  Some providers need to be factorized and considered as thin routing layers on top of other providers. Consider two examples:
  - The inference API should be a routing layer over inference providers, routed using the "model" key
  - The memory banks API is another instance where various memory bank types will be provided by independent providers (e.g., a vector store is served by Chroma while a keyvalue memory can be served by Redis or PGVector)

  This commit introduces a generalized routing layer for this purpose.

* update `apis_to_serve`
* llama_toolchain -> llama_stack
* Codemod from llama_toolchain -> llama_stack
  - added providers/registry
  - cleaned up api/ subdirectories and moved impls away
  - restructured api/api.py
  - from llama_stack.apis.<api> import foo should work now
  - update imports to do llama_stack.apis.<api>
  - update many other imports
  - added __init__, fixed some registry imports
  - updated registry imports
  - create_agentic_system -> create_agent
  - AgenticSystem -> Agent
* Moved some stuff out of common/; re-generated OpenAPI spec
* llama-toolchain -> llama-stack (hyphens)
* add control plane API
* add redis adapter + sqlite provider
* move core -> distribution
* Some more toolchain -> stack changes
* small naming shenanigans
* Removing custom tool and agent utilities and moving them client side
* Move control plane to distribution server for now
* Remove control plane from API list
* no codeshield dependency randomly plzzzzz
* Add "fire" as a dependency
* add back event loggers
* stack configure fixes
* use brave instead of bing in the example client
* add init file so it gets packaged
* add init files so it gets packaged
* Update MANIFEST
* bug fix

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Xi Yan <xiyan@meta.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
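The "Router" layer mentioned above is essentially a provider that owns no backend of its own and only forwards each request to a concrete provider selected by a routing key (the "model" field for inference, the memory bank type for memory). A rough sketch of the idea follows; the class and attribute names are illustrative assumptions for this description, not the actual implementation introduced by this commit.

```
# Hedged sketch of the routing-layer idea from the commit message.
# InferenceRouter and providers_by_model are assumed names, not real
# llama_stack APIs; the point is only the dispatch-by-model-key pattern.
from typing import Any, Dict, List


class InferenceRouter:
    def __init__(self, providers_by_model: Dict[str, Any]):
        # e.g. {"8b-instruct": meta_reference_impl, "70b-instruct": tgi_impl}
        self.providers_by_model = providers_by_model

    async def chat_completion(self, model: str, messages: List[Any], **kwargs):
        provider = self.providers_by_model.get(model)
        if provider is None:
            raise ValueError(f"no inference provider registered for {model!r}")
        # The router adds no behavior of its own beyond choosing the backend.
        return await provider.chat_completion(model=model, messages=messages, **kwargs)
```

The same pattern covers the memory banks example: a routing table keyed by bank type (vector, keyvalue, ...) dispatching to Chroma, Redis, PGVector, and so on.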
This commit is contained in: parent f294eac5f5, commit 9487ad8294
213 changed files with 1725 additions and 1204 deletions
@@ -1,4 +1,5 @@
 include requirements.txt
-include llama_toolchain/data/*.yaml
-include llama_toolchain/core/*.sh
-include llama_toolchain/cli/scripts/*.sh
+include llama_stack/distribution/*.sh
+include llama_stack/cli/scripts/*.sh
+include llama_stack/distribution/example_configs/conda/*.yaml
+include llama_stack/distribution/example_configs/docker/*.yaml
@@ -1,6 +1,6 @@
 # llama-stack
 
-[](https://pypi.org/project/llama-toolchain/)
+[](https://pypi.org/project/llama-stack/)
 [](https://discord.gg/TZAAYNVtrU)
 
 This repository contains the specifications and implementations of the APIs which are part of the Llama Stack.
@@ -42,7 +42,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
 
 ## Installation
 
-You can install this repository as a [package](https://pypi.org/project/llama-toolchain/) with `pip install llama-toolchain`
+You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack`
 
 If you want to install from source:
 
@@ -1,6 +1,6 @@
 # Llama CLI Reference
 
-The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-toolchain` package.
+The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
 
 ### Subcommands
 1. `download`: `llama` cli tools supports downloading the model from Meta or HuggingFace.
@@ -276,16 +276,16 @@ The following command and specifications allows you to get started with building
 ```
 llama stack build <path/to/config>
 ```
-- You will be required to pass in a file path to the build.config file (e.g. `./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_toolchain/configs/distributions/` folder.
+- You will be required to pass in a file path to the build.config file (e.g. `./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_stack/distribution/example_configs/` folder.
 
 The file will be of the contents
 ```
-$ cat ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml
+$ cat ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml
 
 name: 8b-instruct
 distribution_spec:
   distribution_type: local
-  description: Use code from `llama_toolchain` itself to serve all llama stack APIs
+  description: Use code from `llama_stack` itself to serve all llama stack APIs
   docker_image: null
   providers:
     inference: meta-reference
@@ -311,7 +311,7 @@ After this step is complete, a file named `8b-instruct-build.yaml` will be gener
 To specify a different API provider, we can change the `distribution_spec` in our `<name>-build.yaml` config. For example, the following build spec allows you to build a distribution using TGI as the inference API provider.
 
 ```
-$ cat ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml
+$ cat ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml
 
 name: local-tgi-conda-example
 distribution_spec:
@@ -328,7 +328,7 @@ image_type: conda
 
 The following command allows you to build a distribution with TGI as the inference API provider, with the name `tgi`.
 ```
-llama stack build --config ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml --name tgi
+llama stack build --config ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml --name tgi
 ```
 
 We provide some example build configs to help you get started with building with different API providers.
@@ -337,11 +337,11 @@ We provide some example build configs to help you get started with building with
 To build a docker image, simply change the `image_type` to `docker` in our `<name>-build.yaml` file, and run `llama stack build --config <name>-build.yaml`.
 
 ```
-$ cat ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml
+$ cat ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml
 
 name: local-docker-example
 distribution_spec:
-  description: Use code from `llama_toolchain` itself to serve all llama stack APIs
+  description: Use code from `llama_stack` itself to serve all llama stack APIs
   docker_image: null
   providers:
     inference: meta-reference
@@ -354,7 +354,7 @@ image_type: docker
 
 The following command allows you to build a Docker image with the name `docker-local`
 ```
-llama stack build --config ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml --name docker-local
+llama stack build --config ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml --name docker-local
 
 Dockerfile created successfully in /tmp/tmp.I0ifS2c46A/DockerfileFROM python:3.10-slim
 WORKDIR /app
@@ -480,9 +480,9 @@ This server is running a Llama model locally.
 Once the server is setup, we can test it with a client to see the example outputs.
 ```
 cd /path/to/llama-stack
-conda activate <env>  # any environment containing the llama-toolchain pip package will work
+conda activate <env>  # any environment containing the llama-stack pip package will work
 
-python -m llama_toolchain.inference.client localhost 5000
+python -m llama_stack.apis.inference.client localhost 5000
 ```
 
 This will run the chat completion client and query the distribution’s /inference/chat_completion API.
@@ -500,7 +500,7 @@ You know what's even more hilarious? People like you who think they can just Goo
 Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by:
 
 ```
-python -m llama_toolchain.safety.client localhost 5000
+python -m llama_stack.safety.client localhost 5000
 ```
 
 You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.
@@ -1,6 +1,6 @@
 # Getting Started
 
-The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-toolchain` package.
+The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
 
 This guides allows you to quickly get started with building and running a Llama Stack server in < 5 minutes!
 
@@ -9,7 +9,7 @@ This guides allows you to quickly get started with building and running a Llama
 
 **`llama stack build`**
 ```
-llama stack build --config ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml --name my-local-llama-stack
+llama stack build --config ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml --name my-local-llama-stack
 ...
 ...
 Build spec configuration saved at ~/.llama/distributions/conda/my-local-llama-stack-build.yaml
@@ -97,16 +97,16 @@ The following command and specifications allows you to get started with building
 ```
 llama stack build <path/to/config>
 ```
-- You will be required to pass in a file path to the build.config file (e.g. `./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_toolchain/configs/distributions/` folder.
+- You will be required to pass in a file path to the build.config file (e.g. `./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_stack/distribution/example_configs/` folder.
 
 The file will be of the contents
 ```
-$ cat ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml
+$ cat ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml
 
 name: 8b-instruct
 distribution_spec:
   distribution_type: local
-  description: Use code from `llama_toolchain` itself to serve all llama stack APIs
+  description: Use code from `llama_stack` itself to serve all llama stack APIs
   docker_image: null
   providers:
     inference: meta-reference
@@ -132,7 +132,7 @@ After this step is complete, a file named `8b-instruct-build.yaml` will be gener
 To specify a different API provider, we can change the `distribution_spec` in our `<name>-build.yaml` config. For example, the following build spec allows you to build a distribution using TGI as the inference API provider.
 
 ```
-$ cat ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml
+$ cat ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml
 
 name: local-tgi-conda-example
 distribution_spec:
@@ -149,7 +149,7 @@ image_type: conda
 
 The following command allows you to build a distribution with TGI as the inference API provider, with the name `tgi`.
 ```
-llama stack build --config ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml --name tgi
+llama stack build --config ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml --name tgi
 ```
 
 We provide some example build configs to help you get started with building with different API providers.
@@ -158,11 +158,11 @@ We provide some example build configs to help you get started with building with
 To build a docker image, simply change the `image_type` to `docker` in our `<name>-build.yaml` file, and run `llama stack build --config <name>-build.yaml`.
 
 ```
-$ cat ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml
+$ cat ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml
 
 name: local-docker-example
 distribution_spec:
-  description: Use code from `llama_toolchain` itself to serve all llama stack APIs
+  description: Use code from `llama_stack` itself to serve all llama stack APIs
   docker_image: null
   providers:
     inference: meta-reference
@@ -175,7 +175,7 @@ image_type: docker
 
 The following command allows you to build a Docker image with the name `docker-local`
 ```
-llama stack build --config ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml --name docker-local
+llama stack build --config ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml --name docker-local
 
 Dockerfile created successfully in /tmp/tmp.I0ifS2c46A/DockerfileFROM python:3.10-slim
 WORKDIR /app
@@ -294,9 +294,9 @@ This server is running a Llama model locally.
 Once the server is setup, we can test it with a client to see the example outputs.
 ```
 cd /path/to/llama-stack
-conda activate <env>  # any environment containing the llama-toolchain pip package will work
+conda activate <env>  # any environment containing the llama-stack pip package will work
 
-python -m llama_toolchain.inference.client localhost 5000
+python -m llama_stack.apis.inference.client localhost 5000
 ```
 
 This will run the chat completion client and query the distribution’s /inference/chat_completion API.
@@ -314,7 +314,7 @@ You know what's even more hilarious? People like you who think they can just Goo
 Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by:
 
 ```
-python -m llama_toolchain.safety.client localhost 5000
+python -m llama_stack.apis.safety.client localhost 5000
 ```
 
 You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.
llama_stack/apis/agents/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .agents import *  # noqa: F401 F403
@@ -14,10 +14,10 @@ from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_toolchain.common.deployment_types import *  # noqa: F403
-from llama_toolchain.inference.api import *  # noqa: F403
-from llama_toolchain.safety.api import *  # noqa: F403
-from llama_toolchain.memory.api import *  # noqa: F403
+from llama_stack.apis.common.deployment_types import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403
 
 
 @json_schema_type
@@ -26,7 +26,7 @@ class Attachment(BaseModel):
     mime_type: str
 
 
-class AgenticSystemTool(Enum):
+class AgentTool(Enum):
     brave_search = "brave_search"
     wolfram_alpha = "wolfram_alpha"
     photogen = "photogen"
@@ -50,41 +50,35 @@ class SearchEngineType(Enum):
 class SearchToolDefinition(ToolDefinitionCommon):
     # NOTE: brave_search is just a placeholder since model always uses
     # brave_search as tool call name
-    type: Literal[AgenticSystemTool.brave_search.value] = (
-        AgenticSystemTool.brave_search.value
-    )
+    type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
+    api_key: str
     engine: SearchEngineType = SearchEngineType.brave
     remote_execution: Optional[RestAPIExecutionConfig] = None
 
 
 @json_schema_type
 class WolframAlphaToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
-        AgenticSystemTool.wolfram_alpha.value
-    )
+    type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
+    api_key: str
     remote_execution: Optional[RestAPIExecutionConfig] = None
 
 
 @json_schema_type
 class PhotogenToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
+    type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
     remote_execution: Optional[RestAPIExecutionConfig] = None
 
 
 @json_schema_type
 class CodeInterpreterToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.code_interpreter.value] = (
-        AgenticSystemTool.code_interpreter.value
-    )
+    type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
     enable_inline_code_execution: bool = True
     remote_execution: Optional[RestAPIExecutionConfig] = None
 
 
 @json_schema_type
 class FunctionCallToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.function_call.value] = (
-        AgenticSystemTool.function_call.value
-    )
+    type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
     function_name: str
     description: str
     parameters: Dict[str, ToolParamDefinition]
@@ -95,30 +89,30 @@ class _MemoryBankConfigCommon(BaseModel):
     bank_id: str
 
 
-class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
 
 
-class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
     keys: List[str]  # what keys to focus on
 
 
-class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
 
 
-class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
+class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
     type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
     entities: List[str]  # what entities to focus on
 
 
 MemoryBankConfig = Annotated[
     Union[
-        AgenticSystemVectorMemoryBankConfig,
-        AgenticSystemKeyValueMemoryBankConfig,
-        AgenticSystemKeywordMemoryBankConfig,
-        AgenticSystemGraphMemoryBankConfig,
+        AgentVectorMemoryBankConfig,
+        AgentKeyValueMemoryBankConfig,
+        AgentKeywordMemoryBankConfig,
+        AgentGraphMemoryBankConfig,
     ],
     Field(discriminator="type"),
 ]
@@ -158,7 +152,7 @@ MemoryQueryGeneratorConfig = Annotated[
 
 
 class MemoryToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
+    type: Literal[AgentTool.memory.value] = AgentTool.memory.value
     memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
     # This config defines how a query is generated using the messages
     # for memory bank retrieval.
@@ -169,7 +163,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
     max_chunks: int = 10
 
 
-AgenticSystemToolDefinition = Annotated[
+AgentToolDefinition = Annotated[
     Union[
         SearchToolDefinition,
         WolframAlphaToolDefinition,
@@ -275,7 +269,7 @@ class AgentConfigCommon(BaseModel):
     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
 
-    tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
+    tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
@@ -292,7 +286,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon):
     instructions: Optional[str] = None
 
 
-class AgenticSystemTurnResponseEventType(Enum):
+class AgentTurnResponseEventType(Enum):
     step_start = "step_start"
     step_complete = "step_complete"
     step_progress = "step_progress"
@@ -302,9 +296,9 @@ class AgenticSystemTurnResponseEventType(Enum):
 
 
 @json_schema_type
-class AgenticSystemTurnResponseStepStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
-        AgenticSystemTurnResponseEventType.step_start.value
+class AgentTurnResponseStepStartPayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
+        AgentTurnResponseEventType.step_start.value
     )
     step_type: StepType
     step_id: str
@@ -312,20 +306,20 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
 
 
 @json_schema_type
-class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
-        AgenticSystemTurnResponseEventType.step_complete.value
+class AgentTurnResponseStepCompletePayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
+        AgentTurnResponseEventType.step_complete.value
     )
     step_type: StepType
     step_details: Step
 
 
 @json_schema_type
-class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
+class AgentTurnResponseStepProgressPayload(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
-    event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
-        AgenticSystemTurnResponseEventType.step_progress.value
+    event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
+        AgentTurnResponseEventType.step_progress.value
     )
     step_type: StepType
     step_id: str
@@ -336,49 +330,49 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
 
 
 @json_schema_type
-class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
-        AgenticSystemTurnResponseEventType.turn_start.value
+class AgentTurnResponseTurnStartPayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
+        AgentTurnResponseEventType.turn_start.value
     )
     turn_id: str
 
 
 @json_schema_type
-class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
-    event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
-        AgenticSystemTurnResponseEventType.turn_complete.value
+class AgentTurnResponseTurnCompletePayload(BaseModel):
+    event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
+        AgentTurnResponseEventType.turn_complete.value
     )
     turn: Turn
 
 
 @json_schema_type
-class AgenticSystemTurnResponseEvent(BaseModel):
+class AgentTurnResponseEvent(BaseModel):
     """Streamed agent execution response."""
 
     payload: Annotated[
         Union[
-            AgenticSystemTurnResponseStepStartPayload,
-            AgenticSystemTurnResponseStepProgressPayload,
-            AgenticSystemTurnResponseStepCompletePayload,
-            AgenticSystemTurnResponseTurnStartPayload,
-            AgenticSystemTurnResponseTurnCompletePayload,
+            AgentTurnResponseStepStartPayload,
+            AgentTurnResponseStepProgressPayload,
+            AgentTurnResponseStepCompletePayload,
+            AgentTurnResponseTurnStartPayload,
+            AgentTurnResponseTurnCompletePayload,
         ],
         Field(discriminator="event_type"),
     ]
 
 
 @json_schema_type
-class AgenticSystemCreateResponse(BaseModel):
+class AgentCreateResponse(BaseModel):
     agent_id: str
 
 
 @json_schema_type
-class AgenticSystemSessionCreateResponse(BaseModel):
+class AgentSessionCreateResponse(BaseModel):
     session_id: str
 
 
 @json_schema_type
-class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
+class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     agent_id: str
     session_id: str
 
@@ -397,24 +391,24 @@ class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
 
 
 @json_schema_type
-class AgenticSystemTurnResponseStreamChunk(BaseModel):
-    event: AgenticSystemTurnResponseEvent
+class AgentTurnResponseStreamChunk(BaseModel):
+    event: AgentTurnResponseEvent
 
 
 @json_schema_type
-class AgenticSystemStepResponse(BaseModel):
+class AgentStepResponse(BaseModel):
     step: Step
 
 
-class AgenticSystem(Protocol):
-    @webmethod(route="/agentic_system/create")
-    async def create_agentic_system(
+class Agents(Protocol):
+    @webmethod(route="/agents/create")
+    async def create_agent(
         self,
         agent_config: AgentConfig,
-    ) -> AgenticSystemCreateResponse: ...
+    ) -> AgentCreateResponse: ...
 
-    @webmethod(route="/agentic_system/turn/create")
-    async def create_agentic_system_turn(
+    @webmethod(route="/agents/turn/create")
+    async def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
@@ -426,42 +420,40 @@ class AgenticSystem(Protocol):
         ],
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
-    ) -> AgenticSystemTurnResponseStreamChunk: ...
+    ) -> AgentTurnResponseStreamChunk: ...
 
-    @webmethod(route="/agentic_system/turn/get")
-    async def get_agentic_system_turn(
+    @webmethod(route="/agents/turn/get")
+    async def get_agents_turn(
         self,
         agent_id: str,
         turn_id: str,
     ) -> Turn: ...
 
-    @webmethod(route="/agentic_system/step/get")
-    async def get_agentic_system_step(
+    @webmethod(route="/agents/step/get")
+    async def get_agents_step(
         self, agent_id: str, turn_id: str, step_id: str
-    ) -> AgenticSystemStepResponse: ...
+    ) -> AgentStepResponse: ...
 
-    @webmethod(route="/agentic_system/session/create")
-    async def create_agentic_system_session(
+    @webmethod(route="/agents/session/create")
+    async def create_agent_session(
        self,
        agent_id: str,
        session_name: str,
-    ) -> AgenticSystemSessionCreateResponse: ...
+    ) -> AgentSessionCreateResponse: ...
 
-    @webmethod(route="/agentic_system/session/get")
-    async def get_agentic_system_session(
+    @webmethod(route="/agents/session/get")
+    async def get_agents_session(
         self,
         agent_id: str,
         session_id: str,
         turn_ids: Optional[List[str]] = None,
     ) -> Session: ...
 
-    @webmethod(route="/agentic_system/session/delete")
-    async def delete_agentic_system_session(
-        self, agent_id: str, session_id: str
-    ) -> None: ...
+    @webmethod(route="/agents/session/delete")
+    async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...
 
-    @webmethod(route="/agentic_system/delete")
-    async def delete_agentic_system(
+    @webmethod(route="/agents/delete")
+    async def delete_agents(
         self,
         agent_id: str,
     ) -> None: ...
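The `api_key: str` fields added to `SearchToolDefinition` and `WolframAlphaToolDefinition` above are the concrete form of "API Keys passed from Client instead of distro configuration": the caller now constructs tool definitions with its own keys. A minimal sketch of that client-side construction follows, mirroring the example client further down in this diff; the environment variable names come from that client and are otherwise arbitrary.

```
import os

# Sketch: the client supplies search/Wolfram keys itself when building tool
# definitions, rather than reading them from the distribution configuration.
from llama_stack.apis.agents import (
    SearchEngineType,
    SearchToolDefinition,
    WolframAlphaToolDefinition,
)

tools = [
    SearchToolDefinition(
        engine=SearchEngineType.brave,
        api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
    ),
    WolframAlphaToolDefinition(api_key=os.getenv("WOLFRAM_ALPHA_API_KEY", "")),
]
```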
@@ -6,56 +6,58 @@
 import asyncio
 import json
+import os
 from typing import AsyncGenerator
 
 import fire
 
 import httpx
+from dotenv import load_dotenv
 
 from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_toolchain.core.datatypes import RemoteProviderConfig
+from llama_stack.distribution.datatypes import RemoteProviderConfig
 
-from .api import *  # noqa: F403
+from .agents import *  # noqa: F403
 from .event_logger import EventLogger
 
 
+load_dotenv()
 
 
 async def get_client_impl(config: RemoteProviderConfig, _deps):
-    return AgenticSystemClient(config.url)
+    return AgentsClient(config.url)
 
 
 def encodable_dict(d: BaseModel):
     return json.loads(d.json())
 
 
-class AgenticSystemClient(AgenticSystem):
+class AgentsClient(Agents):
     def __init__(self, base_url: str):
         self.base_url = base_url
 
-    async def create_agentic_system(
-        self, agent_config: AgentConfig
-    ) -> AgenticSystemCreateResponse:
+    async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
-                f"{self.base_url}/agentic_system/create",
+                f"{self.base_url}/agents/create",
                 json={
                     "agent_config": encodable_dict(agent_config),
                 },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return AgenticSystemCreateResponse(**response.json())
+            return AgentCreateResponse(**response.json())
 
-    async def create_agentic_system_session(
+    async def create_agent_session(
         self,
         agent_id: str,
         session_name: str,
@@ -63,16 +65,16 @@ class AgenticSystemClient(AgenticSystem):
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return AgenticSystemSessionCreateResponse(**response.json())
+            return AgentSessionCreateResponse(**response.json())
 
-    async def create_agentic_system_turn(
+    async def create_agent_turn(
         self,
-        request: AgenticSystemTurnCreateRequest,
+        request: AgentTurnCreateRequest,
     ) -> AsyncGenerator:
         async with httpx.AsyncClient() as client:
             async with client.stream(
                 "POST",
-                f"{self.base_url}/agentic_system/turn/create",
+                f"{self.base_url}/agents/turn/create",
                 json=encodable_dict(request),
                 headers={"Content-Type": "application/json"},
                 timeout=20,
@@ -86,7 +88,7 @@ class AgenticSystemClient(AgenticSystem):
                                 cprint(data, "red")
                                 continue
 
-                            yield AgenticSystemTurnResponseStreamChunk(**jdata)
+                            yield AgentTurnResponseStreamChunk(**jdata)
                         except Exception as e:
                             print(data)
                             print(f"Error with parsing or validation: {e}")
@@ -102,16 +104,16 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
         tool_prompt_format=ToolPromptFormat.function_tag,
     )
 
-    create_response = await api.create_agentic_system(agent_config)
-    session_response = await api.create_agentic_system_session(
+    create_response = await api.create_agent(agent_config)
+    session_response = await api.create_agent_session(
         agent_id=create_response.agent_id,
         session_name="test_session",
     )
 
     for content in user_prompts:
         cprint(f"User> {content}", color="white", attrs=["bold"])
-        iterator = api.create_agentic_system_turn(
-            AgenticSystemTurnCreateRequest(
+        iterator = api.create_agent_turn(
+            AgentTurnCreateRequest(
                 agent_id=create_response.agent_id,
                 session_id=session_response.session_id,
                 messages=[
@@ -128,11 +130,14 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
 
 
 async def run_main(host: str, port: int):
-    api = AgenticSystemClient(f"http://{host}:{port}")
+    api = AgentsClient(f"http://{host}:{port}")
 
     tool_definitions = [
-        SearchToolDefinition(engine=SearchEngineType.bing),
-        WolframAlphaToolDefinition(),
+        SearchToolDefinition(
+            engine=SearchEngineType.brave,
+            api_key=os.getenv("BRAVE_SEARCH_API_KEY"),
+        ),
+        WolframAlphaToolDefinition(api_key=os.getenv("WOLFRAM_ALPHA_API_KEY")),
         CodeInterpreterToolDefinition(),
     ]
     tool_definitions += [
@@ -165,7 +170,7 @@ async def run_main(host: str, port: int):
 
 
 async def run_rag(host: str, port: int):
-    api = AgenticSystemClient(f"http://{host}:{port}")
+    api = AgentsClient(f"http://{host}:{port}")
 
     urls = [
         "memory_optimizations.rst",
@@ -186,7 +191,7 @@ async def run_rag(host: str, port: int):
     ]
 
     # Alternatively, you can pre-populate the memory bank with documents for example,
-    # using `llama_toolchain.memory.client`. Then you can grab the bank_id
+    # using `llama_stack.memory.client`. Then you can grab the bank_id
     # from the output of that run.
     tool_definitions = [
         MemoryToolDefinition(
@@ -9,12 +9,9 @@ from typing import Optional
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tool_utils import ToolUtils
 
-from termcolor import cprint
+from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
 
-from llama_toolchain.agentic_system.api import (
-    AgenticSystemTurnResponseEventType,
-    StepType,
-)
+from termcolor import cprint
 
 
 class LogEvent:
@@ -40,7 +37,7 @@ class LogEvent:
         cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
 
 
-EventType = AgenticSystemTurnResponseEventType
+EventType = AgentTurnResponseEventType
 
 
 class EventLogger:
llama_stack/apis/batch_inference/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .batch_inference import *  # noqa: F401 F403
@@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_toolchain.inference.api import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
 
 
 @json_schema_type
llama_stack/apis/dataset/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .dataset import *  # noqa: F401 F403
llama_stack/apis/evals/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .evals import *  # noqa: F401 F403
@@ -12,8 +12,8 @@ from llama_models.schema_utils import webmethod
 from pydantic import BaseModel
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_toolchain.dataset.api import *  # noqa: F403
-from llama_toolchain.common.training_types import *  # noqa: F403
+from llama_stack.apis.dataset import *  # noqa: F403
+from llama_stack.apis.common.training_types import *  # noqa: F403
 
 
 class TextGenerationMetric(Enum):
llama_stack/apis/inference/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .inference import *  # noqa: F401 F403
@@ -11,11 +11,13 @@ from typing import Any, AsyncGenerator
 import fire
 import httpx
 
-from llama_toolchain.core.datatypes import RemoteProviderConfig
+from llama_stack.distribution.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint
 
-from .api import (
+from .event_logger import EventLogger
+
+from .inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
@@ -23,7 +25,6 @@ from .api import (
     Inference,
     UserMessage,
 )
-from .event_logger import EventLogger
 
 
 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_toolchain.inference.api import (
+from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
 )
llama_stack/apis/memory/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .memory import *  # noqa: F401 F403
@@ -6,17 +6,18 @@
 import asyncio
 import json
+import os
 from pathlib import Path
 
 from typing import Any, Dict, List, Optional
 
 import fire
 import httpx
 
+from llama_stack.distribution.datatypes import RemoteProviderConfig
 from termcolor import cprint
 
-from llama_toolchain.core.datatypes import RemoteProviderConfig
+from .memory import *  # noqa: F403
 
-from .api import *  # noqa: F403
 from .common.file_utils import data_url_from_file
 
 
llama_stack/apis/models/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .models import *  # noqa: F401 F403
llama_stack/apis/post_training/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .post_training import *  # noqa: F401 F403
@@ -14,8 +14,8 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_toolchain.dataset.api import *  # noqa: F403
-from llama_toolchain.common.training_types import *  # noqa: F403
+from llama_stack.apis.dataset import *  # noqa: F403
+from llama_stack.apis.common.training_types import *  # noqa: F403
 
 
 class OptimizerType(Enum):
llama_stack/apis/reward_scoring/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .reward_scoring import *  # noqa: F401 F403
llama_stack/apis/safety/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .safety import *  # noqa: F401 F403
@@ -14,11 +14,11 @@ import httpx
 
 from llama_models.llama3.api.datatypes import UserMessage
 
-from llama_toolchain.core.datatypes import RemoteProviderConfig
+from llama_stack.distribution.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint
 
-from .api import *  # noqa: F403
+from .safety import *  # noqa: F403
 
 
 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
@@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, validator
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
+from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
 
 
 @json_schema_type
llama_stack/apis/stack.py (new file, 34 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.agents import *  # noqa: F403
from llama_stack.apis.dataset import *  # noqa: F403
from llama_stack.apis.evals import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.batch_inference import *  # noqa: F403
from llama_stack.apis.memory import *  # noqa: F403
from llama_stack.apis.telemetry import *  # noqa: F403
from llama_stack.apis.post_training import *  # noqa: F403
from llama_stack.apis.reward_scoring import *  # noqa: F403
from llama_stack.apis.synthetic_data_generation import *  # noqa: F403
from llama_stack.apis.safety import *  # noqa: F403


class LlamaStack(
    Inference,
    BatchInference,
    Agents,
    RewardScoring,
    Safety,
    SyntheticDataGeneration,
    Datasets,
    Telemetry,
    PostTraining,
    Memory,
    Evaluations,
):
    pass
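The new stack.py does nothing beyond composing every API protocol into a single LlamaStack surface that the server and spec generator can introspect. A self-contained toy sketch of the same composition pattern; the protocol names and methods below are invented for illustration and are not the real llama_stack APIs:

# Toy illustration of the protocol-composition pattern used by LlamaStack.
# ToyInference/ToySafety are stand-ins, not the actual llama_stack protocols.
import inspect
from typing import Protocol


class ToyInference(Protocol):
    async def chat_completion(self, model: str, message: str) -> str: ...


class ToySafety(Protocol):
    async def run_shield(self, message: str) -> bool: ...


class ToyStack(ToyInference, ToySafety):
    """One class that aggregates every per-API protocol."""


# The combined class exposes the union of the API surfaces, so tooling can
# enumerate all endpoints from a single place.
for name, member in inspect.getmembers(ToyStack, inspect.isfunction):
    if not name.startswith("_"):
        print(name)  # chat_completion, run_shield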
llama_stack/apis/synthetic_data_generation/__init__.py (new file, 7 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .synthetic_data_generation import *  # noqa: F401 F403
@@ -13,7 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_toolchain.reward_scoring.api import *  # noqa: F403
+from llama_stack.apis.reward_scoring import *  # noqa: F403


 class FilteringFunction(Enum):
llama_stack/apis/telemetry/__init__.py (new file, 7 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .telemetry import *  # noqa: F401 F403
@@ -20,7 +20,7 @@ from pydantic import BaseModel

 from termcolor import cprint

-from llama_toolchain.cli.subcommand import Subcommand
+from llama_stack.cli.subcommand import Subcommand


 class Download(Subcommand):

@@ -92,7 +92,7 @@ def _hf_download(
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

-from llama_toolchain.common.model_utils import model_local_dir
+from llama_stack.distribution.utils.model_utils import model_local_dir

 repo_id = model.huggingface_repo
 if repo_id is None:

@@ -106,7 +106,7 @@ def _hf_download(
 local_dir=output_dir,
 ignore_patterns=ignore_patterns,
 token=hf_token,
-library_name="llama-toolchain",
+library_name="llama-stack",
 )
 except GatedRepoError:
 parser.error(

@@ -126,7 +126,7 @@ def _hf_download(
 def _meta_download(model: "Model", meta_url: str):
 from llama_models.sku_list import llama_meta_net_info

-from llama_toolchain.common.model_utils import model_local_dir
+from llama_stack.distribution.utils.model_utils import model_local_dir

 output_dir = Path(model_local_dir(model.descriptor()))
 os.makedirs(output_dir, exist_ok=True)

@@ -188,7 +188,7 @@ class Manifest(BaseModel):


 def _download_from_manifest(manifest_file: str):
-from llama_toolchain.common.model_utils import model_local_dir
+from llama_stack.distribution.utils.model_utils import model_local_dir

 with open(manifest_file, "r") as f:
 d = json.load(f)
@@ -31,16 +31,6 @@ class LlamaCLIParser:
 ModelParser.create(subparsers)
 StackParser.create(subparsers)

-# Import sub-commands from agentic_system if they exist
-try:
-    from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
-
-    for module in SUBCOMMAND_MODULES:
-        module.create(subparsers)
-
-except ImportError:
-    pass
-
 def parse_args(self) -> argparse.Namespace:
 return self.parser.parse_args()
@@ -9,12 +9,12 @@ import json

 from llama_models.sku_list import resolve_model

-from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.cli.table import print_table
-from llama_toolchain.common.serialize import EnumEncoder
-
 from termcolor import colored

+from llama_stack.cli.subcommand import Subcommand
+from llama_stack.cli.table import print_table
+from llama_stack.distribution.utils.serialize import EnumEncoder
+

 class ModelDescribe(Subcommand):
     """Show details about a model"""

@@ -6,7 +6,7 @@

 import argparse

-from llama_toolchain.cli.subcommand import Subcommand
+from llama_stack.cli.subcommand import Subcommand


 class ModelDownload(Subcommand):

@@ -19,6 +19,6 @@ class ModelDownload(Subcommand):
 formatter_class=argparse.RawTextHelpFormatter,
 )

-from llama_toolchain.cli.download import setup_download_parser
+from llama_stack.cli.download import setup_download_parser

 setup_download_parser(self.parser)

@@ -8,8 +8,8 @@ import argparse

 from llama_models.sku_list import all_registered_models

-from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.cli.table import print_table
+from llama_stack.cli.subcommand import Subcommand
+from llama_stack.cli.table import print_table


 class ModelList(Subcommand):

@@ -6,12 +6,12 @@

 import argparse

-from llama_toolchain.cli.model.describe import ModelDescribe
-from llama_toolchain.cli.model.download import ModelDownload
-from llama_toolchain.cli.model.list import ModelList
-from llama_toolchain.cli.model.template import ModelTemplate
-
-from llama_toolchain.cli.subcommand import Subcommand
+from llama_stack.cli.model.describe import ModelDescribe
+from llama_stack.cli.model.download import ModelDownload
+from llama_stack.cli.model.list import ModelList
+from llama_stack.cli.model.template import ModelTemplate
+
+from llama_stack.cli.subcommand import Subcommand


 class ModelParser(Subcommand):

@@ -9,7 +9,7 @@ import textwrap

 from termcolor import colored

-from llama_toolchain.cli.subcommand import Subcommand
+from llama_stack.cli.subcommand import Subcommand


 class ModelTemplate(Subcommand):

@@ -75,7 +75,7 @@ class ModelTemplate(Subcommand):
 render_jinja_template,
 )

-from llama_toolchain.cli.table import print_table
+from llama_stack.cli.table import print_table

 if args.name:
 tool_prompt_format = self._prompt_type(args.format)
@@ -6,8 +6,8 @@

 import argparse

-from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.core.datatypes import *  # noqa: F403
+from llama_stack.cli.subcommand import Subcommand
+from llama_stack.distribution.datatypes import *  # noqa: F403
 from pathlib import Path

 import yaml

@@ -29,7 +29,7 @@ class StackBuild(Subcommand):
 self.parser.add_argument(
 "config",
 type=str,
-help="Path to a config file to use for the build. You may find example configs in llama_toolchain/configs/distributions",
+help="Path to a config file to use for the build. You may find example configs in llama_stack/distribution/example_configs",
 )

 self.parser.add_argument(

@@ -44,17 +44,17 @@ class StackBuild(Subcommand):
 import json
 import os

-from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
-from llama_toolchain.common.serialize import EnumEncoder
-from llama_toolchain.core.package import ApiInput, build_package, ImageType
+from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
+from llama_stack.distribution.utils.serialize import EnumEncoder
+from llama_stack.distribution.build import ApiInput, build_image, ImageType
 from termcolor import cprint

 # save build.yaml spec for building same distribution again
 if build_config.image_type == ImageType.docker.value:
 # docker needs build file to be in the llama-stack repo dir to be able to copy over to the image
-llama_toolchain_path = Path(os.path.relpath(__file__)).parent.parent.parent
+llama_stack_path = Path(os.path.relpath(__file__)).parent.parent.parent
 build_dir = (
-llama_toolchain_path / "configs/distributions" / build_config.image_type
+llama_stack_path / "configs/distributions" / build_config.image_type
 )
 else:
 build_dir = DISTRIBS_BASE_DIR / build_config.image_type

@@ -66,7 +66,7 @@ class StackBuild(Subcommand):
 to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
 f.write(yaml.dump(to_write, sort_keys=False))

-build_package(build_config, build_file_path)
+build_image(build_config, build_file_path)

 cprint(
 f"Build spec configuration saved at {str(build_file_path)}",

@@ -74,12 +74,12 @@ class StackBuild(Subcommand):
 )

 def _run_stack_build_command(self, args: argparse.Namespace) -> None:
-from llama_toolchain.common.prompt_for_config import prompt_for_config
-from llama_toolchain.core.dynamic import instantiate_class_type
+from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
+from llama_stack.distribution.utils.dynamic import instantiate_class_type

 if not args.config:
 self.parser.error(
-"No config file specified. Please use `llama stack build /path/to/*-build.yaml`. Example config files can be found in llama_toolchain/configs/distributions"
+"No config file specified. Please use `llama stack build /path/to/*-build.yaml`. Example config files can be found in llama_stack/distribution/example_configs"
 )
 return
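The `llama stack build` path above now expects a `*-build.yaml` (see the example configs added under llama_stack/distribution/example_configs further down in this diff) and deserializes it into a BuildConfig before calling build_image. A rough sketch of that load step; the YAML filename is hypothetical, and it assumes BuildConfig is importable from llama_stack.distribution.datatypes (the star import used above) with the name / distribution_spec / image_type fields the example configs show:

# Hedged sketch: parse a *-build.yaml into BuildConfig the way the CLI does.
# The file path below is illustrative; BuildConfig's import location is assumed.
import yaml

from llama_stack.distribution.datatypes import BuildConfig

with open("llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml") as f:
    build_config = BuildConfig(**yaml.safe_load(f))

print(build_config.name, build_config.image_type)
print(build_config.distribution_spec.providers)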
@@ -13,11 +13,11 @@ import pkg_resources
 import yaml
 from termcolor import cprint

-from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
+from llama_stack.cli.subcommand import Subcommand
+from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR

-from llama_toolchain.common.exec import run_with_pty
-from llama_toolchain.core.datatypes import *  # noqa: F403
+from llama_stack.distribution.utils.exec import run_with_pty
+from llama_stack.distribution.datatypes import *  # noqa: F403
 import os

@@ -49,7 +49,7 @@ class StackConfigure(Subcommand):
 )

 def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
-from llama_toolchain.core.package import ImageType
+from llama_stack.distribution.build import ImageType

 docker_image = None
 build_config_file = Path(args.config)

@@ -66,7 +66,7 @@ class StackConfigure(Subcommand):
 os.makedirs(builds_dir, exist_ok=True)

 script = pkg_resources.resource_filename(
-"llama_toolchain", "core/configure_container.sh"
+"llama_stack", "distribution/configure_container.sh"
 )
 script_args = [script, docker_image, str(builds_dir)]

@@ -95,8 +95,8 @@ class StackConfigure(Subcommand):
 build_config: BuildConfig,
 output_dir: Optional[str] = None,
 ):
-from llama_toolchain.common.serialize import EnumEncoder
-from llama_toolchain.core.configure import configure_api_providers
+from llama_stack.distribution.configure import configure_api_providers
+from llama_stack.distribution.utils.serialize import EnumEncoder

 builds_dir = BUILDS_BASE_DIR / build_config.image_type
 if output_dir:

@@ -105,16 +105,9 @@ class StackConfigure(Subcommand):
 image_name = build_config.name.replace("::", "-")
 run_config_file = builds_dir / f"{image_name}-run.yaml"

-api2providers = build_config.distribution_spec.providers
-
-stub_config = {
-    api_str: {"provider_id": provider}
-    for api_str, provider in api2providers.items()
-}
-
 if run_config_file.exists():
 cprint(
-f"Configuration already exists for {build_config.name}. Will overwrite...",
+f"Configuration already exists at `{str(run_config_file)}`. Will overwrite...",
 "yellow",
 attrs=["bold"],
 )

@@ -123,10 +116,12 @@ class StackConfigure(Subcommand):
 config = StackRunConfig(
 built_at=datetime.now(),
 image_name=image_name,
-providers=stub_config,
+apis_to_serve=[],
+provider_map={},
 )

-config.providers = configure_api_providers(config.providers)
+config = configure_api_providers(config, build_config.distribution_spec)

 config.docker_image = (
 image_name if build_config.image_type == "docker" else None
 )
@@ -6,7 +6,7 @@

 import argparse

-from llama_toolchain.cli.subcommand import Subcommand
+from llama_stack.cli.subcommand import Subcommand


 class StackListApis(Subcommand):

@@ -25,8 +25,8 @@ class StackListApis(Subcommand):
 pass

 def _run_apis_list_cmd(self, args: argparse.Namespace) -> None:
-from llama_toolchain.cli.table import print_table
-from llama_toolchain.core.distribution import stack_apis
+from llama_stack.cli.table import print_table
+from llama_stack.distribution.distribution import stack_apis

 # eventually, this should query a registry at llama.meta.com/llamastack/distributions
 headers = [
@@ -6,7 +6,7 @@

 import argparse

-from llama_toolchain.cli.subcommand import Subcommand
+from llama_stack.cli.subcommand import Subcommand


 class StackListProviders(Subcommand):

@@ -22,7 +22,7 @@ class StackListProviders(Subcommand):
 self.parser.set_defaults(func=self._run_providers_list_cmd)

 def _add_arguments(self):
-from llama_toolchain.core.distribution import stack_apis
+from llama_stack.distribution.distribution import stack_apis

 api_values = [a.value for a in stack_apis()]
 self.parser.add_argument(

@@ -33,8 +33,8 @@ class StackListProviders(Subcommand):
 )

 def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
-from llama_toolchain.cli.table import print_table
-from llama_toolchain.core.distribution import Api, api_providers
+from llama_stack.cli.table import print_table
+from llama_stack.distribution.distribution import Api, api_providers

 all_providers = api_providers()
 providers_for_api = all_providers[Api(args.api)]
@@ -11,8 +11,8 @@ from pathlib import Path
 import pkg_resources
 import yaml

-from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.core.datatypes import *  # noqa: F403
+from llama_stack.cli.subcommand import Subcommand
+from llama_stack.distribution.datatypes import *  # noqa: F403


 class StackRun(Subcommand):

@@ -47,7 +47,7 @@ class StackRun(Subcommand):
 )

 def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
-from llama_toolchain.common.exec import run_with_pty
+from llama_stack.distribution.utils.exec import run_with_pty

 if not args.config:
 self.parser.error("Must specify a config file to run")

@@ -67,14 +67,14 @@ class StackRun(Subcommand):

 if config.docker_image:
 script = pkg_resources.resource_filename(
-"llama_toolchain",
-"core/start_container.sh",
+"llama_stack",
+"distribution/start_container.sh",
 )
 run_args = [script, config.docker_image]
 else:
 script = pkg_resources.resource_filename(
-"llama_toolchain",
-"core/start_conda_env.sh",
+"llama_stack",
+"distribution/start_conda_env.sh",
 )
 run_args = [
 script,
@@ -6,7 +6,7 @@

 import argparse

-from llama_toolchain.cli.subcommand import Subcommand
+from llama_stack.cli.subcommand import Subcommand

 from .build import StackBuild
 from .configure import StackConfigure
@@ -4,26 +4,20 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import json
-import os
-from datetime import datetime
 from enum import Enum
 from typing import List, Optional

 import pkg_resources
-import yaml
-
-from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
-from llama_toolchain.common.exec import run_with_pty
-from llama_toolchain.common.serialize import EnumEncoder
 from pydantic import BaseModel

 from termcolor import cprint

-from llama_toolchain.core.datatypes import *  # noqa: F403
+from llama_stack.distribution.utils.exec import run_with_pty
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
 from pathlib import Path

-from llama_toolchain.core.distribution import api_providers, SERVER_DEPENDENCIES
+from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES


 class ImageType(Enum):

@@ -41,7 +35,7 @@ class ApiInput(BaseModel):
 provider: str


-def build_package(build_config: BuildConfig, build_file_path: Path):
+def build_image(build_config: BuildConfig, build_file_path: Path):
 package_deps = Dependencies(
 docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
 pip_packages=SERVER_DEPENDENCIES,

@@ -49,21 +43,32 @@ def build_package(build_config: BuildConfig, build_file_path: Path):

 # extend package dependencies based on providers spec
 all_providers = api_providers()
-for api_str, provider in build_config.distribution_spec.providers.items():
+for (
+    api_str,
+    provider_or_providers,
+) in build_config.distribution_spec.providers.items():
 providers_for_api = all_providers[Api(api_str)]
-if provider not in providers_for_api:
-    raise ValueError(
-        f"Provider `{provider}` is not available for API `{api_str}`"
-    )
-
-provider_spec = providers_for_api[provider]
-package_deps.pip_packages.extend(provider_spec.pip_packages)
-if provider_spec.docker_image:
-    raise ValueError("A stack's dependencies cannot have a docker image")
+
+providers = (
+    provider_or_providers
+    if isinstance(provider_or_providers, list)
+    else [provider_or_providers]
+)
+
+for provider in providers:
+    if provider not in providers_for_api:
+        raise ValueError(
+            f"Provider `{provider}` is not available for API `{api_str}`"
+        )
+
+    provider_spec = providers_for_api[provider]
+    package_deps.pip_packages.extend(provider_spec.pip_packages)
+    if provider_spec.docker_image:
+        raise ValueError("A stack's dependencies cannot have a docker image")

 if build_config.image_type == ImageType.docker.value:
 script = pkg_resources.resource_filename(
-"llama_toolchain", "core/build_container.sh"
+"llama_stack", "distribution/build_container.sh"
 )
 args = [
 script,

@@ -74,7 +79,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
 ]
 else:
 script = pkg_resources.resource_filename(
-"llama_toolchain", "core/build_conda_env.sh"
+"llama_stack", "distribution/build_conda_env.sh"
 )
 args = [
 script,
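Since DistributionSpec.providers can now map an API either to one provider id or to a list of them, build_image normalizes each value to a list before validating providers and collecting their pip dependencies. A small self-contained sketch of that normalization; the provider ids and package names here are made up:

# Toy version of build_image's dependency collection: each API may declare
# either a single provider id or a list of them.
from typing import Dict, List, Union

# hypothetical registry: api -> provider id -> pip packages
TOY_REGISTRY: Dict[str, Dict[str, List[str]]] = {
    "inference": {"meta-reference": ["torch"], "remote::example": ["httpx"]},
    "memory": {"meta-reference": ["faiss-cpu"]},
}

providers: Dict[str, Union[str, List[str]]] = {
    "inference": ["meta-reference", "remote::example"],  # two providers -> routing needed
    "memory": "meta-reference",
}

pip_packages: List[str] = []
for api_str, provider_or_providers in providers.items():
    normalized = (
        provider_or_providers
        if isinstance(provider_or_providers, list)
        else [provider_or_providers]
    )
    for provider in normalized:
        if provider not in TOY_REGISTRY[api_str]:
            raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")
        pip_packages.extend(TOY_REGISTRY[api_str][provider])

print(pip_packages)  # ['torch', 'httpx', 'faiss-cpu']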
@@ -7,11 +7,11 @@
 # the root directory of this source tree.

 LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
-LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
+LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}

-if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
-echo "Using llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR"
+if [ -n "$LLAMA_STACK_DIR" ]; then
+echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
 fi
 if [ -n "$LLAMA_MODELS_DIR" ]; then
 echo "Using llama-models-dir=$LLAMA_MODELS_DIR"

@@ -78,19 +78,19 @@ ensure_conda_env_python310() {
 if [ -n "$TEST_PYPI_VERSION" ]; then
 # these packages are damaged in test-pypi, so install them first
 pip install fastapi libcst
-pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-toolchain==$TEST_PYPI_VERSION $pip_dependencies
+pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION $pip_dependencies
 else
-# Re-installing llama-toolchain in the new conda environment
-if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
-if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
-printf "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}\n" >&2
+# Re-installing llama-stack in the new conda environment
+if [ -n "$LLAMA_STACK_DIR" ]; then
+if [ ! -d "$LLAMA_STACK_DIR" ]; then
+printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
 exit 1
 fi

-printf "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR\n"
-pip install --no-cache-dir -e "$LLAMA_TOOLCHAIN_DIR"
+printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
+pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
 else
-pip install --no-cache-dir llama-toolchain
+pip install --no-cache-dir llama-stack
 fi

 if [ -n "$LLAMA_MODELS_DIR" ]; then
@@ -1,7 +1,7 @@
 #!/bin/bash

 LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
-LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
+LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}

 if [ "$#" -ne 4 ]; then

@@ -55,17 +55,17 @@ RUN apt-get update && apt-get install -y \

 EOF

-toolchain_mount="/app/llama-toolchain-source"
+stack_mount="/app/llama-stack-source"
 models_mount="/app/llama-models-source"

-if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
-if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
-echo "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2
+if [ -n "$LLAMA_STACK_DIR" ]; then
+if [ ! -d "$LLAMA_STACK_DIR" ]; then
+echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2
 exit 1
 fi
-add_to_docker "RUN pip install $toolchain_mount"
+add_to_docker "RUN pip install $stack_mount"
 else
-add_to_docker "RUN pip install llama-toolchain"
+add_to_docker "RUN pip install llama-stack"
 fi

 if [ -n "$LLAMA_MODELS_DIR" ]; then

@@ -90,7 +90,7 @@ add_to_docker <<EOF
 # This would be good in production but for debugging flexibility lets not add it right now
 # We need a more solid production ready entrypoint.sh anyway
 #
-# ENTRYPOINT ["python", "-m", "llama_toolchain.core.server"]
+# ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]

 EOF

@@ -101,8 +101,8 @@ cat $TEMP_DIR/Dockerfile
 printf "\n"

 mounts=""
-if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
-mounts="$mounts -v $(readlink -f $LLAMA_TOOLCHAIN_DIR):$toolchain_mount"
+if [ -n "$LLAMA_STACK_DIR" ]; then
+mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
 fi
 if [ -n "$LLAMA_MODELS_DIR" ]; then
 mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
llama_stack/distribution/configure.py (new file, 110 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel

from llama_stack.distribution.datatypes import *  # noqa: F403
from termcolor import cprint

from llama_stack.distribution.distribution import api_providers, stack_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type

from llama_stack.distribution.utils.prompt_for_config import prompt_for_config


# These are hacks so we can re-use the `prompt_for_config` utility
# This needs a bunch of work to be made very user friendly.
class ReqApis(BaseModel):
    apis_to_serve: List[str]


def make_routing_entry_type(config_class: Any):
    class BaseModelWithConfig(BaseModel):
        routing_key: str
        config: config_class

    return BaseModelWithConfig


# TODO: make sure we can deal with existing configuration values correctly
# instead of just overwriting them
def configure_api_providers(
    config: StackRunConfig, spec: DistributionSpec
) -> StackRunConfig:
    cprint("Configuring APIs to serve...", "white", attrs=["bold"])
    print("Enter comma-separated list of APIs to serve:")

    apis = config.apis_to_serve or list(spec.providers.keys())
    apis = [a for a in apis if a != "telemetry"]
    req_apis = ReqApis(
        apis_to_serve=apis,
    )
    req_apis = prompt_for_config(ReqApis, req_apis)
    config.apis_to_serve = req_apis.apis_to_serve
    print("")

    apis = [v.value for v in stack_apis()]
    all_providers = api_providers()

    for api_str in spec.providers.keys():
        if api_str not in apis:
            raise ValueError(f"Unknown API `{api_str}`")

        cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
        api = Api(api_str)

        provider_or_providers = spec.providers[api_str]
        if isinstance(provider_or_providers, list) and len(provider_or_providers) > 1:
            print(
                "You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n"
            )

            routing_entries = []
            for p in provider_or_providers:
                print(f"Configuring provider `{p}`...")
                provider_spec = all_providers[api][p]
                config_type = instantiate_class_type(provider_spec.config_class)

                # TODO: we need to validate the routing keys, and
                # perhaps it is better if we break this out into asking
                # for a routing key separately from the associated config
                wrapper_type = make_routing_entry_type(config_type)
                rt_entry = prompt_for_config(wrapper_type, None)

                routing_entries.append(
                    ProviderRoutingEntry(
                        provider_id=p,
                        routing_key=rt_entry.routing_key,
                        config=rt_entry.config.dict(),
                    )
                )
            config.provider_map[api_str] = routing_entries
        else:
            p = (
                provider_or_providers[0]
                if isinstance(provider_or_providers, list)
                else provider_or_providers
            )
            print(f"Configuring provider `{p}`...")
            provider_spec = all_providers[api][p]
            config_type = instantiate_class_type(provider_spec.config_class)
            try:
                provider_config = config.provider_map.get(api_str)
                if provider_config:
                    existing = config_type(**provider_config.config)
                else:
                    existing = None
            except Exception:
                existing = None
            cfg = prompt_for_config(config_type, existing)
            config.provider_map[api_str] = GenericProviderConfig(
                provider_id=p,
                config=cfg.dict(),
            )

    return config
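configure_api_providers leaves each provider_map entry as either a single GenericProviderConfig or, when multiple providers back one API, a list of ProviderRoutingEntry objects. A hedged sketch of building the same shape of StackRunConfig by hand; provider ids, the routing key, and the empty config payloads are illustrative, and the types come from llama_stack.distribution.datatypes (see the datatypes changes further down):

# Hedged sketch of a run config with both shapes of provider_map entry.
from datetime import datetime

from llama_stack.distribution.datatypes import (
    GenericProviderConfig,
    ProviderRoutingEntry,
    StackRunConfig,
)

run_config = StackRunConfig(
    built_at=datetime.now(),
    image_name="local-example",
    apis_to_serve=["inference", "safety"],
    provider_map={
        # single provider for the safety API
        "safety": GenericProviderConfig(
            provider_id="meta-reference",
            config={},  # illustrative: real configs are prompted interactively
        ),
        # routing table for the inference API, keyed by model name
        "inference": [
            ProviderRoutingEntry(
                provider_id="meta-reference",
                routing_key="Meta-Llama3.1-8B-Instruct",  # hypothetical routing key
                config={},
            ),
        ],
    },
)
print(run_config.apis_to_serve)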
llama_stack/distribution/control_plane/__init__.py (new file, 7 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .control_plane import *  # noqa: F401 F403
New file (Redis control-plane adapter, __init__.py, 15 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import RedisImplConfig


async def get_adapter_impl(config: RedisImplConfig, _deps):
    from .redis import RedisControlPlaneAdapter

    impl = RedisControlPlaneAdapter(config)
    await impl.initialize()
    return impl

New file (Redis adapter config, 19 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class RedisImplConfig(BaseModel):
    url: str = Field(
        description="The URL for the Redis server",
    )
    namespace: Optional[str] = Field(
        default=None,
        description="All keys will be prefixed with this namespace",
    )
New file (Redis control-plane adapter implementation, 62 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime, timedelta
from typing import Any, List, Optional

from redis.asyncio import Redis

from llama_stack.apis.control_plane import *  # noqa: F403


from .config import RedisImplConfig


class RedisControlPlaneAdapter(ControlPlane):
    def __init__(self, config: RedisImplConfig):
        self.config = config

    async def initialize(self) -> None:
        self.redis = Redis.from_url(self.config.url)

    def _namespaced_key(self, key: str) -> str:
        if not self.config.namespace:
            return key
        return f"{self.config.namespace}:{key}"

    async def set(
        self, key: str, value: Any, expiration: Optional[datetime] = None
    ) -> None:
        key = self._namespaced_key(key)
        await self.redis.set(key, value)
        if expiration:
            await self.redis.expireat(key, expiration)

    async def get(self, key: str) -> Optional[ControlPlaneValue]:
        key = self._namespaced_key(key)
        value = await self.redis.get(key)
        if value is None:
            return None
        ttl = await self.redis.ttl(key)
        expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None
        return ControlPlaneValue(key=key, value=value, expiration=expiration)

    async def delete(self, key: str) -> None:
        key = self._namespaced_key(key)
        await self.redis.delete(key)

    async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)

        keys = await self.redis.keys(f"{start_key}*")
        result = []
        for key in keys:
            if key <= end_key:
                value = await self.get(key)
                if value:
                    result.append(value)
        return result
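A hedged usage sketch for the Redis adapter above. It assumes the package lives at the module path named in the control-plane registry later in this diff (llama_stack.providers.adapters.control_plane.redis), that its __init__.py re-exports RedisImplConfig and get_adapter_impl as shown, and that a Redis server is listening on the given URL:

# Hedged sketch: set and get a key through the Redis control-plane adapter.
import asyncio
from datetime import datetime, timedelta

from llama_stack.providers.adapters.control_plane.redis import (
    RedisImplConfig,
    get_adapter_impl,
)


async def main() -> None:
    config = RedisImplConfig(url="redis://localhost:6379", namespace="demo")
    plane = await get_adapter_impl(config, {})  # second argument is the unused _deps

    # store a value that expires in an hour
    await plane.set("run:1:status", "running", expiration=datetime.now() + timedelta(hours=1))
    entry = await plane.get("run:1:status")
    print(entry.key, entry.value, entry.expiration)


asyncio.run(main())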
New file (SQLite control-plane provider, __init__.py, 15 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import SqliteControlPlaneConfig


async def get_provider_impl(config: SqliteControlPlaneConfig, _deps):
    from .control_plane import SqliteControlPlane

    impl = SqliteControlPlane(config)
    await impl.initialize()
    return impl

New file (SQLite control-plane provider config, 19 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class SqliteControlPlaneConfig(BaseModel):
    db_path: str = Field(
        description="File path for the sqlite database",
    )
    table_name: str = Field(
        default="llamastack_control_plane",
        description="Table into which all the keys will be placed",
    )
New file (SQLite control-plane implementation, 79 lines):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from datetime import datetime
from typing import Any, List, Optional

import aiosqlite

from llama_stack.apis.control_plane import *  # noqa: F403


from .config import SqliteControlPlaneConfig


class SqliteControlPlane(ControlPlane):
    def __init__(self, config: SqliteControlPlaneConfig):
        self.db_path = config.db_path
        self.table_name = config.table_name

    async def initialize(self):
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expiration TIMESTAMP
                )
                """
            )
            await db.commit()

    async def set(
        self, key: str, value: Any, expiration: Optional[datetime] = None
    ) -> None:
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
                (key, json.dumps(value), expiration),
            )
            await db.commit()

    async def get(self, key: str) -> Optional[ControlPlaneValue]:
        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute(
                f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
            ) as cursor:
                row = await cursor.fetchone()
                if row is None:
                    return None
                value, expiration = row
                return ControlPlaneValue(
                    key=key, value=json.loads(value), expiration=expiration
                )

    async def delete(self, key: str) -> None:
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
            await db.commit()

    async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute(
                f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
                (start_key, end_key),
            ) as cursor:
                result = []
                async for row in cursor:
                    key, value, expiration = row
                    result.append(
                        ControlPlaneValue(
                            key=key, value=json.loads(value), expiration=expiration
                        )
                    )
                return result
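A similar hedged sketch for the SQLite provider, assuming the module path given in the registry entry below (llama_stack.providers.impls.sqlite.control_plane). Values round-trip through json.dumps/json.loads, so anything JSON-serializable can be stored, and range is a lexicographic scan over the TEXT keys:

# Hedged sketch: store and range-scan keys with the SQLite control plane.
import asyncio

from llama_stack.providers.impls.sqlite.control_plane import (
    SqliteControlPlaneConfig,
    get_provider_impl,
)


async def main() -> None:
    config = SqliteControlPlaneConfig(db_path="/tmp/control_plane.db")
    plane = await get_provider_impl(config, {})

    await plane.set("session:001", {"turns": 3})
    await plane.set("session:002", {"turns": 7})

    # keys are TEXT, so this is a lexicographic [start, end] window
    for entry in await plane.range("session:000", "session:999"):
        print(entry.key, entry.value)


asyncio.run(main())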
llama_stack/distribution/control_plane/api.py (new file, 35 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime
from typing import Any, List, Optional, Protocol

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel


@json_schema_type
class ControlPlaneValue(BaseModel):
    key: str
    value: Any
    expiration: Optional[datetime] = None


@json_schema_type
class ControlPlane(Protocol):
    @webmethod(route="/control_plane/set")
    async def set(
        self, key: str, value: Any, expiration: Optional[datetime] = None
    ) -> None: ...

    @webmethod(route="/control_plane/get", method="GET")
    async def get(self, key: str) -> Optional[ControlPlaneValue]: ...

    @webmethod(route="/control_plane/delete")
    async def delete(self, key: str) -> None: ...

    @webmethod(route="/control_plane/range", method="GET")
    async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: ...
llama_stack/distribution/control_plane/registry.py (new file, 29 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.distribution.datatypes import *  # noqa: F403


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.control_plane,
            provider_id="sqlite",
            pip_packages=["aiosqlite"],
            module="llama_stack.providers.impls.sqlite.control_plane",
            config_class="llama_stack.providers.impls.sqlite.control_plane.SqliteControlPlaneConfig",
        ),
        remote_provider_spec(
            Api.control_plane,
            AdapterSpec(
                adapter_id="redis",
                pip_packages=["redis"],
                module="llama_stack.providers.adapters.control_plane.redis",
            ),
        ),
    ]
@@ -6,7 +6,7 @@

 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from llama_models.schema_utils import json_schema_type

@@ -17,7 +17,7 @@ from pydantic import BaseModel, Field, validator
 class Api(Enum):
 inference = "inference"
 safety = "safety"
-agentic_system = "agentic_system"
+agents = "agents"
 memory = "memory"
 telemetry = "telemetry"

@@ -43,6 +43,33 @@ class ProviderSpec(BaseModel):
 )


+@json_schema_type
+class RouterProviderSpec(ProviderSpec):
+    provider_id: str = "router"
+    config_class: str = ""
+
+    docker_image: Optional[str] = None
+
+    inner_specs: List[ProviderSpec]
+    module: str = Field(
+        ...,
+        description="""
+Fully-qualified name of the module to import. The module is expected to have:
+
+ - `get_router_impl(config, provider_specs, deps)`: returns the router implementation
+""",
+    )
+
+    @property
+    def pip_packages(self) -> List[str]:
+        raise AssertionError("Should not be called on RouterProviderSpec")
+
+
+class GenericProviderConfig(BaseModel):
+    provider_id: str
+    config: Dict[str, Any]
+
+
 @json_schema_type
 class AdapterSpec(BaseModel):
 adapter_id: str = Field(

@@ -124,7 +151,7 @@ as being "Llama Stack compatible"
 def module(self) -> str:
 if self.adapter:
 return self.adapter.module
-return f"llama_toolchain.{self.api.value}.client"
+return f"llama_stack.apis.{self.api.value}.client"

 @property
 def pip_packages(self) -> List[str]:

@@ -140,7 +167,7 @@ def remote_provider_spec(
 config_class = (
 adapter.config_class
 if adapter and adapter.config_class
-else "llama_toolchain.core.datatypes.RemoteProviderConfig"
+else "llama_stack.distribution.datatypes.RemoteProviderConfig"
 )
 provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"

@@ -156,12 +183,23 @@ class DistributionSpec(BaseModel):
 description="Description of the distribution",
 )
 docker_image: Optional[str] = None
-providers: Dict[str, str] = Field(
+providers: Dict[str, Union[str, List[str]]] = Field(
 default_factory=dict,
-description="Provider Types for each of the APIs provided by this distribution",
+description="""
+Provider Types for each of the APIs provided by this distribution. If you
+select multiple providers, you should provide an appropriate 'routing_map'
+in the runtime configuration to help route to the correct provider.""",
 )


+@json_schema_type
+class ProviderRoutingEntry(GenericProviderConfig):
+    routing_key: str
+
+
+ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
+
+
 @json_schema_type
 class StackRunConfig(BaseModel):
 built_at: datetime

@@ -181,12 +219,22 @@ this could be just a hash
 default=None,
 description="Reference to the conda environment if this package refers to a conda environment",
 )
-providers: Dict[str, Any] = Field(
-default_factory=dict,
-description="""
-Provider configurations for each of the APIs provided by this package. This includes configurations for
-the dependencies of these providers as well.
-""",
+apis_to_serve: List[str] = Field(
+    description="""
+The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
+)
+provider_map: Dict[str, ProviderMapEntry] = Field(
+    description="""
+Provider configurations for each of the APIs provided by this package.
+
+Given an API, you can specify a single provider or a "routing table". Each entry in the routing
+table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
+
+As examples:
+- the "inference" API interprets the routing_key as a "model"
+- the "memory" API interprets the routing_key as the type of a "memory bank"
+
+The key may support wild-cards also to help route to the correct provider.""",
 )
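The provider_map docstring above describes a routing table whose key is interpreted per API (a model name for inference, a memory-bank type for memory). A toy sketch of how a router layer could resolve such a table; this lookup is illustrative only, since the real router is whatever module a RouterProviderSpec points at:

# Toy routing-table lookup: pick the provider whose routing_key matches the
# request's key (a model name for inference, a bank type for memory).
from dataclasses import dataclass
from typing import List


@dataclass
class ToyRoutingEntry:
    provider_id: str
    routing_key: str


def resolve(table: List[ToyRoutingEntry], routing_key: str) -> str:
    for entry in table:
        if entry.routing_key == routing_key:
            return entry.provider_id
    raise KeyError(f"No provider registered for routing key `{routing_key}`")


inference_table = [
    ToyRoutingEntry(provider_id="meta-reference", routing_key="llama-8b"),   # hypothetical keys
    ToyRoutingEntry(provider_id="remote::example", routing_key="llama-70b"),
]

print(resolve(inference_table, "llama-8b"))   # meta-reference
print(resolve(inference_table, "llama-70b"))  # remote::example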
@@ -8,18 +8,19 @@ import importlib
 import inspect
 from typing import Dict, List

-from llama_toolchain.agentic_system.api import AgenticSystem
-from llama_toolchain.inference.api import Inference
-from llama_toolchain.memory.api import Memory
-from llama_toolchain.safety.api import Safety
-from llama_toolchain.telemetry.api import Telemetry
+from llama_stack.apis.agents import Agents
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.telemetry import Telemetry

 from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec

 # These are the dependencies needed by the distribution server.
-# `llama-toolchain` is automatically installed by the installation script.
+# `llama-stack` is automatically installed by the installation script.
 SERVER_DEPENDENCIES = [
     "fastapi",
+    "fire",
     "uvicorn",
 ]

@@ -34,7 +35,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
 protocols = {
 Api.inference: Inference,
 Api.safety: Safety,
-Api.agentic_system: AgenticSystem,
+Api.agents: Agents,
 Api.memory: Memory,
 Api.telemetry: Telemetry,
 }

@@ -67,7 +68,7 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
 ret = {}
 for api in stack_apis():
 name = api.name.lower()
-module = importlib.import_module(f"llama_toolchain.{name}.providers")
+module = importlib.import_module(f"llama_stack.providers.registry.{name}")
 ret[api] = {
 "remote": remote_provider_spec(api),
 **{a.provider_id: a for a in module.available_providers()},
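api_providers() now discovers providers by importing llama_stack.providers.registry.<api> and calling its available_providers(), rather than the old llama_toolchain.<api>.providers layout. A stripped-down sketch of that convention (it assumes the package is installed; error handling is omitted):

# Hedged sketch of the registry-module convention used by api_providers():
# each API has a module llama_stack.providers.registry.<api> exposing
# available_providers() -> List[ProviderSpec].
import importlib


def providers_for(api_name: str) -> dict:
    module = importlib.import_module(f"llama_stack.providers.registry.{api_name}")
    return {spec.provider_id: spec for spec in module.available_providers()}


# e.g. list the registered inference providers
for provider_id in providers_for("inference"):
    print(provider_id)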
@ -0,0 +1,10 @@
|
||||||
|
name: local-conda-example
|
||||||
|
distribution_spec:
|
||||||
|
description: Use code from `llama_stack` itself to serve all llama stack APIs
|
||||||
|
providers:
|
||||||
|
inference: meta-reference
|
||||||
|
memory: meta-reference
|
||||||
|
safety: meta-reference
|
||||||
|
agents: meta-reference
|
||||||
|
telemetry: meta-reference
|
||||||
|
image_type: conda
|
|
@@ -3,8 +3,8 @@ distribution_spec:
   description: Use Fireworks.ai for running LLM inference
   providers:
     inference: remote::fireworks
-    memory: meta-reference-faiss
+    memory: meta-reference
     safety: meta-reference
-    agentic_system: meta-reference
-    telemetry: console
+    agents: meta-reference
+    telemetry: meta-reference
 image_type: conda
@@ -3,8 +3,8 @@ distribution_spec:
   description: Like local, but use ollama for running LLM inference
   providers:
     inference: remote::ollama
-    memory: meta-reference-faiss
+    memory: meta-reference
     safety: meta-reference
-    agentic_system: meta-reference
-    telemetry: console
+    agents: meta-reference
+    telemetry: meta-reference
 image_type: conda
@@ -3,8 +3,8 @@ distribution_spec:
   description: Use TGI (local or with Hugging Face Inference Endpoints for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint).
   providers:
     inference: remote::tgi
-    memory: meta-reference-faiss
+    memory: meta-reference
     safety: meta-reference
-    agentic_system: meta-reference
-    telemetry: console
+    agents: meta-reference
+    telemetry: meta-reference
 image_type: conda
@@ -3,8 +3,8 @@ distribution_spec:
   description: Use Together.ai for running LLM inference
   providers:
     inference: remote::together
-    memory: meta-reference-faiss
+    memory: meta-reference
     safety: meta-reference
-    agentic_system: meta-reference
-    telemetry: console
+    agents: meta-reference
+    telemetry: meta-reference
 image_type: conda
@@ -0,0 +1,10 @@
+name: local-docker-example
+distribution_spec:
+  description: Use code from `llama_stack` itself to serve all llama stack APIs
+  providers:
+    inference: meta-reference
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
+image_type: docker
@@ -9,6 +9,7 @@ import inspect
 import json
 import signal
 import traceback

 from collections.abc import (
     AsyncGenerator as AsyncGeneratorABC,
     AsyncIterator as AsyncIteratorABC,
@@ -38,16 +39,16 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated

-from llama_toolchain.telemetry.tracing import (
+from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
     setup_logger,
     SpanStatus,
     start_trace,
 )
+from llama_stack.distribution.datatypes import *  # noqa: F403

-from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
-from .distribution import api_endpoints, api_providers
-from .dynamic import instantiate_provider
+from llama_stack.distribution.distribution import api_endpoints, api_providers
+from llama_stack.distribution.utils.dynamic import instantiate_provider


 def is_async_iterator_type(typ):
@@ -271,61 +272,80 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
     return [by_id[x] for x in stack]


-def resolve_impls(
-    provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
-) -> Dict[Api, Any]:
-    provider_configs = config["providers"]
-    provider_specs = topological_sort(provider_specs.values())
-
-    impls = {}
-    for provider_spec in provider_specs:
-        api = provider_spec.api
-        if api.value not in provider_configs:
-            raise ValueError(
-                f"Could not find provider_spec config for {api}. Please add it to the config"
-            )
+def snake_to_camel(snake_str):
+    return "".join(word.capitalize() for word in snake_str.split("_"))
+
+
+async def resolve_impls(
+    provider_map: Dict[str, ProviderMapEntry],
+) -> Dict[Api, Any]:
+    """
+    Does two things:
+    - flatmaps, sorts and resolves the providers in dependency order
+    - for each API, produces either a (local, passthrough or router) implementation
+    """
+    all_providers = api_providers()
+
+    specs = {}
+    for api_str, item in provider_map.items():
+        api = Api(api_str)
+        providers = all_providers[api]
+
+        if isinstance(item, GenericProviderConfig):
+            if item.provider_id not in providers:
+                raise ValueError(
+                    f"Unknown provider `{provider_id}` is not available for API `{api}`"
+                )
+            specs[api] = providers[item.provider_id]
+        else:
+            assert isinstance(item, list)
+            inner_specs = []
+            for rt_entry in item:
+                if rt_entry.provider_id not in providers:
+                    raise ValueError(
+                        f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
+                    )
+                inner_specs.append(providers[rt_entry.provider_id])
+
+            specs[api] = RouterProviderSpec(
+                api=api,
+                module=f"llama_stack.providers.routers.{api.value.lower()}",
+                api_dependencies=[],
+                inner_specs=inner_specs,
+            )

-        if isinstance(provider_spec, InlineProviderSpec):
-            deps = {api: impls[api] for api in provider_spec.api_dependencies}
-        else:
-            deps = {}
-        provider_config = provider_configs[api.value]
-        impl = instantiate_provider(provider_spec, provider_config, deps)
+    sorted_specs = topological_sort(specs.values())
+
+    impls = {}
+    for spec in sorted_specs:
+        api = spec.api
+
+        deps = {api: impls[api] for api in spec.api_dependencies}
+        impl = await instantiate_provider(spec, deps, provider_map[api.value])
         impls[api] = impl

-    return impls
+    return impls, specs


 def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
     with open(yaml_config, "r") as fp:
-        config = yaml.safe_load(fp)
+        config = StackRunConfig(**yaml.safe_load(fp))

     app = FastAPI()

-    all_endpoints = api_endpoints()
-    all_providers = api_providers()
-
-    provider_specs = {}
-    for api_str, provider_config in config["providers"].items():
-        api = Api(api_str)
-        providers = all_providers[api]
-        provider_id = provider_config["provider_id"]
-        if provider_id not in providers:
-            raise ValueError(
-                f"Unknown provider `{provider_id}` is not available for API `{api}`"
-            )
-
-        provider_specs[api] = providers[provider_id]
-
-    impls = resolve_impls(provider_specs, config)
+    impls, specs = asyncio.run(resolve_impls(config.provider_map))
+
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])

-    for provider_spec in provider_specs.values():
-        api = provider_spec.api
+    all_endpoints = api_endpoints()
+
+    apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
+    for api_str in apis_to_serve:
+        api = Api(api_str)
         endpoints = all_endpoints[api]
         impl = impls[api]

+        provider_spec = specs[api]
         if (
             isinstance(provider_spec, RemoteProviderSpec)
             and provider_spec.adapter is None
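For context, a minimal run config consumed by this reworked main() might look roughly like the sketch below. The top-level keys (apis_to_serve, provider_map) and the provider entry fields (provider_id, config) come from the code above; the concrete values are illustrative assumptions, and a real StackRunConfig may require additional fields not shown here.

    apis_to_serve:
      - inference
      - safety
    provider_map:
      inference:
        provider_id: meta-reference
        config: {}
      safety:
        provider_id: meta-reference
        config: {}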
@@ -37,6 +37,6 @@ eval "$(conda shell.bash hook)"
 conda deactivate && conda activate "$env_name"

 $CONDA_PREFIX/bin/python \
-  -m llama_toolchain.core.server \
+  -m llama_stack.distribution.server.server \
   --yaml_config "$yaml_config" \
   --port "$port" "$@"
@@ -38,6 +38,6 @@ podman run -it \
   -p $port:$port \
   -v "$yaml_config:/app/config.yaml" \
   $docker_image \
-  python -m llama_toolchain.core.server \
+  python -m llama_stack.distribution.server.server \
   --yaml_config /app/config.yaml \
   --port $port "$@"
llama_stack/distribution/utils/dynamic.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import importlib
+from typing import Any, Dict
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+def instantiate_class_type(fully_qualified_name):
+    module_name, class_name = fully_qualified_name.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, class_name)
+
+
+# returns a class implementing the protocol corresponding to the Api
+async def instantiate_provider(
+    provider_spec: ProviderSpec,
+    deps: Dict[str, Any],
+    provider_config: ProviderMapEntry,
+):
+    module = importlib.import_module(provider_spec.module)
+
+    args = []
+    if isinstance(provider_spec, RemoteProviderSpec):
+        if provider_spec.adapter:
+            method = "get_adapter_impl"
+        else:
+            method = "get_client_impl"
+
+        assert isinstance(provider_config, GenericProviderConfig)
+        config_type = instantiate_class_type(provider_spec.config_class)
+        config = config_type(**provider_config.config)
+        args = [config, deps]
+    elif isinstance(provider_spec, RouterProviderSpec):
+        method = "get_router_impl"
+
+        assert isinstance(provider_config, list)
+        inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
+        inner_impls = []
+        for routing_entry in provider_config:
+            impl = await instantiate_provider(
+                inner_specs[routing_entry.provider_id],
+                deps,
+                routing_entry,
+            )
+            inner_impls.append((routing_entry.routing_key, impl))
+
+        config = None
+        args = [inner_impls, deps]
+    else:
+        method = "get_provider_impl"
+
+        assert isinstance(provider_config, GenericProviderConfig)
+        config_type = instantiate_class_type(provider_spec.config_class)
+        config = config_type(**provider_config.config)
+        args = [config, deps]
+
+    fn = getattr(module, method)
+    impl = await fn(*args)
+    impl.__provider_spec__ = provider_spec
+    impl.__provider_config__ = config
+    return impl
@@ -27,6 +27,12 @@ def is_list_of_primitives(field_type):
     return False


+def is_basemodel_without_fields(typ):
+    return (
+        inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
+    )
+
+
 def can_recurse(typ):
     return (
         inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
@@ -151,6 +157,11 @@ def prompt_for_config(
         if get_origin(field_type) is Literal:
             continue

+        # Skip fields with no type annotations
+        if is_basemodel_without_fields(field_type):
+            config_data[field_name] = field_type()
+            continue
+
         if inspect.isclass(field_type) and issubclass(field_type, Enum):
             prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
             while True:
@@ -254,6 +265,20 @@ def prompt_for_config(
                     print(f"{str(e)}")
                     continue

+                elif get_origin(field_type) is dict:
+                    try:
+                        value = json.loads(user_input)
+                        if not isinstance(value, dict):
+                            raise ValueError(
+                                "Input must be a JSON-encoded dictionary"
+                            )
+
+                    except json.JSONDecodeError:
+                        print(
+                            "Invalid JSON. Please enter a valid JSON-encoded dict."
+                        )
+                        continue
+
                 # Convert the input to the correct type
                 elif inspect.isclass(field_type) and issubclass(
                     field_type, BaseModel
Some files were not shown because too many files have changed in this diff.