API Updates (#73)

* API keys passed from the client instead of the distro configuration (illustrated below)
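
For illustration (this mirrors the updated example client in this commit), the Brave
Search key is now read by the caller, e.g. from an environment variable, and handed
over with the tool definition instead of living in the distro config:

```python
# Sketch based on the example agents client changed in this commit; the import
# path and field names follow the new llama_stack.apis.agents module.
import os

from llama_stack.apis.agents import SearchEngineType, SearchToolDefinition

search_tool = SearchToolDefinition(
    engine=SearchEngineType.brave,
    api_key=os.getenv("BRAVE_SEARCH_API_KEY"),  # client-supplied API key
)
```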

* delete distribution registry

* Rename things to drop the word "package"

* Introduce a "Router" layer for providers

Some providers need to be refactored and treated as thin routing
layers on top of other providers. Consider two examples:

- The inference API should be a routing layer over inference providers,
  routed using the "model" key
- The memory banks API is another instance where various memory bank
  types will be provided by independent providers (e.g., a vector store
  is served by Chroma while a keyvalue memory can be served by Redis or
  PGVector)

This commit introduces a generalized routing layer for this purpose.
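
As a rough sketch of the idea (the class and method names below are illustrative,
not the actual llama_stack implementation): the router owns a routing table that
maps a routing key, such as a model name, to a concrete provider, and forwards
each call unchanged.

```python
# Minimal sketch, assuming a simplified async inference interface.
from typing import Any, Dict, List, Protocol


class InferenceProvider(Protocol):
    async def chat_completion(self, model: str, messages: List[Any]) -> Any: ...


class InferenceRouter:
    """Thin routing layer: dispatches each request on its `model` key."""

    def __init__(self, routing_table: Dict[str, InferenceProvider]) -> None:
        # routing_table maps a model name to the provider that serves it
        self.routing_table = routing_table

    async def chat_completion(self, model: str, messages: List[Any]) -> Any:
        provider = self.routing_table.get(model)
        if provider is None:
            raise ValueError(f"No provider registered for model `{model}`")
        # The router adds no behavior of its own; it only delegates.
        return await provider.chat_completion(model=model, messages=messages)
```

A memory-banks router would look the same, keyed on the memory bank type instead
of the model.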

* update `apis_to_serve`

* llama_toolchain -> llama_stack

* Codemod from llama_toolchain -> llama_stack

- added providers/registry
- cleaned up api/ subdirectories and moved impls away
- restructured api/api.py
- from llama_stack.apis.<api> import foo should work now
- update imports to do llama_stack.apis.<api>
- update many other imports
- added __init__, fixed some registry imports
- updated registry imports
- create_agentic_system -> create_agent
- AgenticSystem -> Agent

* Moved some stuff out of common/; re-generated OpenAPI spec

* llama-toolchain -> llama-stack (hyphens)

* add control plane API

* add redis adapter + sqlite provider

* move core -> distribution

* Some more toolchain -> stack changes

* Small naming cleanups

* Removing custom tool and agent utilities and moving them client side

* Move control plane to distribution server for now

* Remove control plane from API list

* Remove the stray codeshield dependency

* Add "fire" as a dependency

* add back event loggers

* stack configure fixes

* use brave instead of bing in the example client

* add init file so it gets packaged

* add init files so they get packaged

* Update MANIFEST

* bug fix

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Xi Yan <xiyan@meta.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Committed by GitHub on behalf of Ashwin Bharambe, 2024-09-17 19:51:35 -07:00
commit 9487ad8294 (parent f294eac5f5)
213 changed files with 1725 additions and 1204 deletions

@ -1,4 +1,5 @@
include requirements.txt
include llama_toolchain/data/*.yaml
include llama_toolchain/core/*.sh
include llama_toolchain/cli/scripts/*.sh
include llama_stack/distribution/*.sh
include llama_stack/cli/scripts/*.sh
include llama_stack/distribution/example_configs/conda/*.yaml
include llama_stack/distribution/example_configs/docker/*.yaml

@ -1,6 +1,6 @@
# llama-stack
[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-toolchain)](https://pypi.org/project/llama-toolchain/)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/)
[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/TZAAYNVtrU)
This repository contains the specifications and implementations of the APIs which are part of the Llama Stack.
@ -42,7 +42,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
## Installation
You can install this repository as a [package](https://pypi.org/project/llama-toolchain/) with `pip install llama-toolchain`
You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack`
If you want to install from source:

@ -1,6 +1,6 @@
# Llama CLI Reference
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-toolchain` package.
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
### Subcommands
1. `download`: `llama` cli tools supports downloading the model from Meta or HuggingFace.
@ -276,16 +276,16 @@ The following command and specifications allows you to get started with building
```
llama stack build <path/to/config>
```
- You will be required to pass in a file path to the build.config file (e.g. `./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_toolchain/configs/distributions/` folder.
- You will be required to pass in a file path to the build.config file (e.g. `./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_stack/distribution/example_configs/` folder.
The file will be of the contents
```
$ cat ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml
$ cat ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml
name: 8b-instruct
distribution_spec:
distribution_type: local
description: Use code from `llama_toolchain` itself to serve all llama stack APIs
description: Use code from `llama_stack` itself to serve all llama stack APIs
docker_image: null
providers:
inference: meta-reference
@ -311,7 +311,7 @@ After this step is complete, a file named `8b-instruct-build.yaml` will be gener
To specify a different API provider, we can change the `distribution_spec` in our `<name>-build.yaml` config. For example, the following build spec allows you to build a distribution using TGI as the inference API provider.
```
$ cat ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml
$ cat ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml
name: local-tgi-conda-example
distribution_spec:
@ -328,7 +328,7 @@ image_type: conda
The following command allows you to build a distribution with TGI as the inference API provider, with the name `tgi`.
```
llama stack build --config ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml --name tgi
llama stack build --config ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml --name tgi
```
We provide some example build configs to help you get started with building with different API providers.
@ -337,11 +337,11 @@ We provide some example build configs to help you get started with building with
To build a docker image, simply change the `image_type` to `docker` in our `<name>-build.yaml` file, and run `llama stack build --config <name>-build.yaml`.
```
$ cat ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml
$ cat ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml
name: local-docker-example
distribution_spec:
description: Use code from `llama_toolchain` itself to serve all llama stack APIs
description: Use code from `llama_stack` itself to serve all llama stack APIs
docker_image: null
providers:
inference: meta-reference
@ -354,7 +354,7 @@ image_type: docker
The following command allows you to build a Docker image with the name `docker-local`
```
llama stack build --config ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml --name docker-local
llama stack build --config ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml --name docker-local
Dockerfile created successfully in /tmp/tmp.I0ifS2c46A/DockerfileFROM python:3.10-slim
WORKDIR /app
@ -480,9 +480,9 @@ This server is running a Llama model locally.
Once the server is setup, we can test it with a client to see the example outputs.
```
cd /path/to/llama-stack
conda activate <env> # any environment containing the llama-toolchain pip package will work
conda activate <env> # any environment containing the llama-stack pip package will work
python -m llama_toolchain.inference.client localhost 5000
python -m llama_stack.apis.inference.client localhost 5000
```
This will run the chat completion client and query the distributions /inference/chat_completion API.
@ -500,7 +500,7 @@ You know what's even more hilarious? People like you who think they can just Goo
Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by:
```
python -m llama_toolchain.safety.client localhost 5000
python -m llama_stack.safety.client localhost 5000
```
You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.

@ -1,6 +1,6 @@
# Getting Started
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-toolchain` package.
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
This guide allows you to quickly get started with building and running a Llama Stack server in < 5 minutes!
@ -9,7 +9,7 @@ This guides allows you to quickly get started with building and running a Llama
**`llama stack build`**
```
llama stack build --config ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml --name my-local-llama-stack
llama stack build --config ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml --name my-local-llama-stack
...
...
Build spec configuration saved at ~/.llama/distributions/conda/my-local-llama-stack-build.yaml
@ -97,16 +97,16 @@ The following command and specifications allows you to get started with building
```
llama stack build <path/to/config>
```
- You will be required to pass in a file path to the build.config file (e.g. `./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_toolchain/configs/distributions/` folder.
- You will be required to pass in a file path to the build.config file (e.g. `./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_stack/distribution/example_configs/` folder.
The file will be of the contents
```
$ cat ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml
$ cat ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml
name: 8b-instruct
distribution_spec:
distribution_type: local
description: Use code from `llama_toolchain` itself to serve all llama stack APIs
description: Use code from `llama_stack` itself to serve all llama stack APIs
docker_image: null
providers:
inference: meta-reference
@ -132,7 +132,7 @@ After this step is complete, a file named `8b-instruct-build.yaml` will be gener
To specify a different API provider, we can change the `distribution_spec` in our `<name>-build.yaml` config. For example, the following build spec allows you to build a distribution using TGI as the inference API provider.
```
$ cat ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml
$ cat ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml
name: local-tgi-conda-example
distribution_spec:
@ -149,7 +149,7 @@ image_type: conda
The following command allows you to build a distribution with TGI as the inference API provider, with the name `tgi`.
```
llama stack build --config ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml --name tgi
llama stack build --config ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml --name tgi
```
We provide some example build configs to help you get started with building with different API providers.
@ -158,11 +158,11 @@ We provide some example build configs to help you get started with building with
To build a docker image, simply change the `image_type` to `docker` in our `<name>-build.yaml` file, and run `llama stack build --config <name>-build.yaml`.
```
$ cat ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml
$ cat ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml
name: local-docker-example
distribution_spec:
description: Use code from `llama_toolchain` itself to serve all llama stack APIs
description: Use code from `llama_stack` itself to serve all llama stack APIs
docker_image: null
providers:
inference: meta-reference
@ -175,7 +175,7 @@ image_type: docker
The following command allows you to build a Docker image with the name `docker-local`
```
llama stack build --config ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml --name docker-local
llama stack build --config ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml --name docker-local
Dockerfile created successfully in /tmp/tmp.I0ifS2c46A/DockerfileFROM python:3.10-slim
WORKDIR /app
@ -294,9 +294,9 @@ This server is running a Llama model locally.
Once the server is setup, we can test it with a client to see the example outputs.
```
cd /path/to/llama-stack
conda activate <env> # any environment containing the llama-toolchain pip package will work
conda activate <env> # any environment containing the llama-stack pip package will work
python -m llama_toolchain.inference.client localhost 5000
python -m llama_stack.apis.inference.client localhost 5000
```
This will run the chat completion client and query the distributions /inference/chat_completion API.
@ -314,7 +314,7 @@ You know what's even more hilarious? People like you who think they can just Goo
Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by:
```
python -m llama_toolchain.safety.client localhost 5000
python -m llama_stack.apis.safety.client localhost 5000
```
You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo.

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .agents import * # noqa: F401 F403

@ -14,10 +14,10 @@ from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.common.deployment_types import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403
from llama_stack.apis.common.deployment_types import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
@json_schema_type
@ -26,7 +26,7 @@ class Attachment(BaseModel):
mime_type: str
class AgenticSystemTool(Enum):
class AgentTool(Enum):
brave_search = "brave_search"
wolfram_alpha = "wolfram_alpha"
photogen = "photogen"
@ -50,41 +50,35 @@ class SearchEngineType(Enum):
class SearchToolDefinition(ToolDefinitionCommon):
# NOTE: brave_search is just a placeholder since model always uses
# brave_search as tool call name
type: Literal[AgenticSystemTool.brave_search.value] = (
AgenticSystemTool.brave_search.value
)
type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
api_key: str
engine: SearchEngineType = SearchEngineType.brave
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
AgenticSystemTool.wolfram_alpha.value
)
type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
api_key: str
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class PhotogenToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.code_interpreter.value] = (
AgenticSystemTool.code_interpreter.value
)
type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
enable_inline_code_execution: bool = True
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class FunctionCallToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.function_call.value] = (
AgenticSystemTool.function_call.value
)
type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
function_name: str
description: str
parameters: Dict[str, ToolParamDefinition]
@ -95,30 +89,30 @@ class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon):
class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
keys: List[str] # what keys to focus on
class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon):
class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon):
class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
AgenticSystemVectorMemoryBankConfig,
AgenticSystemKeyValueMemoryBankConfig,
AgenticSystemKeywordMemoryBankConfig,
AgenticSystemGraphMemoryBankConfig,
AgentVectorMemoryBankConfig,
AgentKeyValueMemoryBankConfig,
AgentKeywordMemoryBankConfig,
AgentGraphMemoryBankConfig,
],
Field(discriminator="type"),
]
@ -158,7 +152,7 @@ MemoryQueryGeneratorConfig = Annotated[
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
type: Literal[AgentTool.memory.value] = AgentTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
# This config defines how a query is generated using the messages
# for memory bank retrieval.
@ -169,7 +163,7 @@ class MemoryToolDefinition(ToolDefinitionCommon):
max_chunks: int = 10
AgenticSystemToolDefinition = Annotated[
AgentToolDefinition = Annotated[
Union[
SearchToolDefinition,
WolframAlphaToolDefinition,
@ -275,7 +269,7 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
@ -292,7 +286,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None
class AgenticSystemTurnResponseEventType(Enum):
class AgentTurnResponseEventType(Enum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
@ -302,9 +296,9 @@ class AgenticSystemTurnResponseEventType(Enum):
@json_schema_type
class AgenticSystemTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = (
AgenticSystemTurnResponseEventType.step_start.value
class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
AgentTurnResponseEventType.step_start.value
)
step_type: StepType
step_id: str
@ -312,20 +306,20 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel):
@json_schema_type
class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = (
AgenticSystemTurnResponseEventType.step_complete.value
class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
AgentTurnResponseEventType.step_complete.value
)
step_type: StepType
step_details: Step
@json_schema_type
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
AgenticSystemTurnResponseEventType.step_progress.value
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
AgentTurnResponseEventType.step_progress.value
)
step_type: StepType
step_id: str
@ -336,49 +330,49 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
@json_schema_type
class AgenticSystemTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = (
AgenticSystemTurnResponseEventType.turn_start.value
class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
AgentTurnResponseEventType.turn_start.value
)
turn_id: str
@json_schema_type
class AgenticSystemTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = (
AgenticSystemTurnResponseEventType.turn_complete.value
class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
AgentTurnResponseEventType.turn_complete.value
)
turn: Turn
@json_schema_type
class AgenticSystemTurnResponseEvent(BaseModel):
class AgentTurnResponseEvent(BaseModel):
"""Streamed agent execution response."""
payload: Annotated[
Union[
AgenticSystemTurnResponseStepStartPayload,
AgenticSystemTurnResponseStepProgressPayload,
AgenticSystemTurnResponseStepCompletePayload,
AgenticSystemTurnResponseTurnStartPayload,
AgenticSystemTurnResponseTurnCompletePayload,
AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
],
Field(discriminator="event_type"),
]
@json_schema_type
class AgenticSystemCreateResponse(BaseModel):
class AgentCreateResponse(BaseModel):
agent_id: str
@json_schema_type
class AgenticSystemSessionCreateResponse(BaseModel):
class AgentSessionCreateResponse(BaseModel):
session_id: str
@json_schema_type
class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
agent_id: str
session_id: str
@ -397,24 +391,24 @@ class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
@json_schema_type
class AgenticSystemTurnResponseStreamChunk(BaseModel):
event: AgenticSystemTurnResponseEvent
class AgentTurnResponseStreamChunk(BaseModel):
event: AgentTurnResponseEvent
@json_schema_type
class AgenticSystemStepResponse(BaseModel):
class AgentStepResponse(BaseModel):
step: Step
class AgenticSystem(Protocol):
@webmethod(route="/agentic_system/create")
async def create_agentic_system(
class Agents(Protocol):
@webmethod(route="/agents/create")
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgenticSystemCreateResponse: ...
) -> AgentCreateResponse: ...
@webmethod(route="/agentic_system/turn/create")
async def create_agentic_system_turn(
@webmethod(route="/agents/turn/create")
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
@ -426,42 +420,40 @@ class AgenticSystem(Protocol):
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AgenticSystemTurnResponseStreamChunk: ...
) -> AgentTurnResponseStreamChunk: ...
@webmethod(route="/agentic_system/turn/get")
async def get_agentic_system_turn(
@webmethod(route="/agents/turn/get")
async def get_agents_turn(
self,
agent_id: str,
turn_id: str,
) -> Turn: ...
@webmethod(route="/agentic_system/step/get")
async def get_agentic_system_step(
@webmethod(route="/agents/step/get")
async def get_agents_step(
self, agent_id: str, turn_id: str, step_id: str
) -> AgenticSystemStepResponse: ...
) -> AgentStepResponse: ...
@webmethod(route="/agentic_system/session/create")
async def create_agentic_system_session(
@webmethod(route="/agents/session/create")
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgenticSystemSessionCreateResponse: ...
) -> AgentSessionCreateResponse: ...
@webmethod(route="/agentic_system/session/get")
async def get_agentic_system_session(
@webmethod(route="/agents/session/get")
async def get_agents_session(
self,
agent_id: str,
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
@webmethod(route="/agentic_system/session/delete")
async def delete_agentic_system_session(
self, agent_id: str, session_id: str
) -> None: ...
@webmethod(route="/agents/session/delete")
async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ...
@webmethod(route="/agentic_system/delete")
async def delete_agentic_system(
@webmethod(route="/agents/delete")
async def delete_agents(
self,
agent_id: str,
) -> None: ...

@ -6,56 +6,58 @@
import asyncio
import json
import os
from typing import AsyncGenerator
import fire
import httpx
from dotenv import load_dotenv
from pydantic import BaseModel
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.core.datatypes import RemoteProviderConfig
from llama_stack.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
from .agents import * # noqa: F403
from .event_logger import EventLogger
load_dotenv()
async def get_client_impl(config: RemoteProviderConfig, _deps):
return AgenticSystemClient(config.url)
return AgentsClient(config.url)
def encodable_dict(d: BaseModel):
return json.loads(d.json())
class AgenticSystemClient(AgenticSystem):
class AgentsClient(Agents):
def __init__(self, base_url: str):
self.base_url = base_url
async def create_agentic_system(
self, agent_config: AgentConfig
) -> AgenticSystemCreateResponse:
async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/agentic_system/create",
f"{self.base_url}/agents/create",
json={
"agent_config": encodable_dict(agent_config),
},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return AgenticSystemCreateResponse(**response.json())
return AgentCreateResponse(**response.json())
async def create_agentic_system_session(
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgenticSystemSessionCreateResponse:
) -> AgentSessionCreateResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/agentic_system/session/create",
f"{self.base_url}/agents/session/create",
json={
"agent_id": agent_id,
"session_name": session_name,
@ -63,16 +65,16 @@ class AgenticSystemClient(AgenticSystem):
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return AgenticSystemSessionCreateResponse(**response.json())
return AgentSessionCreateResponse(**response.json())
async def create_agentic_system_turn(
async def create_agent_turn(
self,
request: AgenticSystemTurnCreateRequest,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
f"{self.base_url}/agentic_system/turn/create",
f"{self.base_url}/agents/turn/create",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
@ -86,7 +88,7 @@ class AgenticSystemClient(AgenticSystem):
cprint(data, "red")
continue
yield AgenticSystemTurnResponseStreamChunk(**jdata)
yield AgentTurnResponseStreamChunk(**jdata)
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
@ -102,16 +104,16 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
tool_prompt_format=ToolPromptFormat.function_tag,
)
create_response = await api.create_agentic_system(agent_config)
session_response = await api.create_agentic_system_session(
create_response = await api.create_agent(agent_config)
session_response = await api.create_agent_session(
agent_id=create_response.agent_id,
session_name="test_session",
)
for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agentic_system_turn(
AgenticSystemTurnCreateRequest(
iterator = api.create_agent_turn(
AgentTurnCreateRequest(
agent_id=create_response.agent_id,
session_id=session_response.session_id,
messages=[
@ -128,11 +130,14 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
async def run_main(host: str, port: int):
api = AgenticSystemClient(f"http://{host}:{port}")
api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [
SearchToolDefinition(engine=SearchEngineType.bing),
WolframAlphaToolDefinition(),
SearchToolDefinition(
engine=SearchEngineType.brave,
api_key=os.getenv("BRAVE_SEARCH_API_KEY"),
),
WolframAlphaToolDefinition(api_key=os.getenv("WOLFRAM_ALPHA_API_KEY")),
CodeInterpreterToolDefinition(),
]
tool_definitions += [
@ -165,7 +170,7 @@ async def run_main(host: str, port: int):
async def run_rag(host: str, port: int):
api = AgenticSystemClient(f"http://{host}:{port}")
api = AgentsClient(f"http://{host}:{port}")
urls = [
"memory_optimizations.rst",
@ -186,7 +191,7 @@ async def run_rag(host: str, port: int):
]
# Alternatively, you can pre-populate the memory bank with documents for example,
# using `llama_toolchain.memory.client`. Then you can grab the bank_id
# using `llama_stack.memory.client`. Then you can grab the bank_id
# from the output of that run.
tool_definitions = [
MemoryToolDefinition(

@ -9,12 +9,9 @@ from typing import Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_toolchain.agentic_system.api import (
AgenticSystemTurnResponseEventType,
StepType,
)
from termcolor import cprint
class LogEvent:
@ -40,7 +37,7 @@ class LogEvent:
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
EventType = AgenticSystemTurnResponseEventType
EventType = AgentTurnResponseEventType
class EventLogger:

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .batch_inference import * # noqa: F401 F403

@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
@json_schema_type

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .dataset import * # noqa: F401 F403

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .evals import * # noqa: F401 F403

@ -12,8 +12,8 @@ from llama_models.schema_utils import webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.dataset.api import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
class TextGenerationMetric(Enum):

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inference import * # noqa: F401 F403

@ -11,11 +11,13 @@ from typing import Any, AsyncGenerator
import fire
import httpx
from llama_toolchain.core.datatypes import RemoteProviderConfig
from llama_stack.distribution.datatypes import RemoteProviderConfig
from pydantic import BaseModel
from termcolor import cprint
from .api import (
from .event_logger import EventLogger
from .inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
@ -23,7 +25,6 @@ from .api import (
Inference,
UserMessage,
)
from .event_logger import EventLogger
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_toolchain.inference.api import (
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
)

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .memory import * # noqa: F401 F403

@ -6,17 +6,18 @@
import asyncio
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
import fire
import httpx
from llama_stack.distribution.datatypes import RemoteProviderConfig
from termcolor import cprint
from llama_toolchain.core.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
from .memory import * # noqa: F403
from .common.file_utils import data_url_from_file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .models import * # noqa: F401 F403

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .post_training import * # noqa: F401 F403

@ -14,8 +14,8 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.dataset.api import * # noqa: F403
from llama_toolchain.common.training_types import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
class OptimizerType(Enum):

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .reward_scoring import * # noqa: F401 F403

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .safety import * # noqa: F401 F403

@ -14,11 +14,11 @@ import httpx
from llama_models.llama3.api.datatypes import UserMessage
from llama_toolchain.core.datatypes import RemoteProviderConfig
from llama_stack.distribution.datatypes import RemoteProviderConfig
from pydantic import BaseModel
from termcolor import cprint
from .api import * # noqa: F403
from .safety import * # noqa: F403
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:

@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, validator
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
@json_schema_type

llama_stack/apis/stack.py (new file, 34 lines)

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.apis.evals import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.reward_scoring import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
class LlamaStack(
Inference,
BatchInference,
Agents,
RewardScoring,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
Memory,
Evaluations,
):
pass

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .synthetic_data_generation import * # noqa: F401 F403

@ -13,7 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.reward_scoring.api import * # noqa: F403
from llama_stack.apis.reward_scoring import * # noqa: F403
class FilteringFunction(Enum):

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .telemetry import * # noqa: F401 F403

@ -20,7 +20,7 @@ from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_stack.cli.subcommand import Subcommand
class Download(Subcommand):
@ -92,7 +92,7 @@ def _hf_download(
from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from llama_toolchain.common.model_utils import model_local_dir
from llama_stack.distribution.utils.model_utils import model_local_dir
repo_id = model.huggingface_repo
if repo_id is None:
@ -106,7 +106,7 @@ def _hf_download(
local_dir=output_dir,
ignore_patterns=ignore_patterns,
token=hf_token,
library_name="llama-toolchain",
library_name="llama-stack",
)
except GatedRepoError:
parser.error(
@ -126,7 +126,7 @@ def _hf_download(
def _meta_download(model: "Model", meta_url: str):
from llama_models.sku_list import llama_meta_net_info
from llama_toolchain.common.model_utils import model_local_dir
from llama_stack.distribution.utils.model_utils import model_local_dir
output_dir = Path(model_local_dir(model.descriptor()))
os.makedirs(output_dir, exist_ok=True)
@ -188,7 +188,7 @@ class Manifest(BaseModel):
def _download_from_manifest(manifest_file: str):
from llama_toolchain.common.model_utils import model_local_dir
from llama_stack.distribution.utils.model_utils import model_local_dir
with open(manifest_file, "r") as f:
d = json.load(f)

@ -31,16 +31,6 @@ class LlamaCLIParser:
ModelParser.create(subparsers)
StackParser.create(subparsers)
# Import sub-commands from agentic_system if they exist
try:
from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES
for module in SUBCOMMAND_MODULES:
module.create(subparsers)
except ImportError:
pass
def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()

@ -9,12 +9,12 @@ import json
from llama_models.sku_list import resolve_model
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table
from llama_toolchain.common.serialize import EnumEncoder
from termcolor import colored
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.distribution.utils.serialize import EnumEncoder
class ModelDescribe(Subcommand):
"""Show details about a model"""

@ -6,7 +6,7 @@
import argparse
from llama_toolchain.cli.subcommand import Subcommand
from llama_stack.cli.subcommand import Subcommand
class ModelDownload(Subcommand):
@ -19,6 +19,6 @@ class ModelDownload(Subcommand):
formatter_class=argparse.RawTextHelpFormatter,
)
from llama_toolchain.cli.download import setup_download_parser
from llama_stack.cli.download import setup_download_parser
setup_download_parser(self.parser)

@ -8,8 +8,8 @@ import argparse
from llama_models.sku_list import all_registered_models
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class ModelList(Subcommand):

@ -6,12 +6,12 @@
import argparse
from llama_toolchain.cli.model.describe import ModelDescribe
from llama_toolchain.cli.model.download import ModelDownload
from llama_toolchain.cli.model.list import ModelList
from llama_toolchain.cli.model.template import ModelTemplate
from llama_stack.cli.model.describe import ModelDescribe
from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.template import ModelTemplate
from llama_toolchain.cli.subcommand import Subcommand
from llama_stack.cli.subcommand import Subcommand
class ModelParser(Subcommand):

@ -9,7 +9,7 @@ import textwrap
from termcolor import colored
from llama_toolchain.cli.subcommand import Subcommand
from llama_stack.cli.subcommand import Subcommand
class ModelTemplate(Subcommand):
@ -75,7 +75,7 @@ class ModelTemplate(Subcommand):
render_jinja_template,
)
from llama_toolchain.cli.table import print_table
from llama_stack.cli.table import print_table
if args.name:
tool_prompt_format = self._prompt_type(args.format)

@ -6,8 +6,8 @@
import argparse
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import * # noqa: F403
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path
import yaml
@ -29,7 +29,7 @@ class StackBuild(Subcommand):
self.parser.add_argument(
"config",
type=str,
help="Path to a config file to use for the build. You may find example configs in llama_toolchain/configs/distributions",
help="Path to a config file to use for the build. You may find example configs in llama_stack/distribution/example_configs",
)
self.parser.add_argument(
@ -44,17 +44,17 @@ class StackBuild(Subcommand):
import json
import os
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.core.package import ApiInput, build_package, ImageType
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.serialize import EnumEncoder
from llama_stack.distribution.build import ApiInput, build_image, ImageType
from termcolor import cprint
# save build.yaml spec for building same distribution again
if build_config.image_type == ImageType.docker.value:
# docker needs build file to be in the llama-stack repo dir to be able to copy over to the image
llama_toolchain_path = Path(os.path.relpath(__file__)).parent.parent.parent
llama_stack_path = Path(os.path.relpath(__file__)).parent.parent.parent
build_dir = (
llama_toolchain_path / "configs/distributions" / build_config.image_type
llama_stack_path / "configs/distributions" / build_config.image_type
)
else:
build_dir = DISTRIBS_BASE_DIR / build_config.image_type
@ -66,7 +66,7 @@ class StackBuild(Subcommand):
to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder))
f.write(yaml.dump(to_write, sort_keys=False))
build_package(build_config, build_file_path)
build_image(build_config, build_file_path)
cprint(
f"Build spec configuration saved at {str(build_file_path)}",
@ -74,12 +74,12 @@ class StackBuild(Subcommand):
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.core.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.distribution.utils.dynamic import instantiate_class_type
if not args.config:
self.parser.error(
"No config file specified. Please use `llama stack build /path/to/*-build.yaml`. Example config files can be found in llama_toolchain/configs/distributions"
"No config file specified. Please use `llama stack build /path/to/*-build.yaml`. Example config files can be found in llama_stack/distribution/example_configs"
)
return

@ -13,11 +13,11 @@ import pkg_resources
import yaml
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.core.datatypes import * # noqa: F403
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403
import os
@ -49,7 +49,7 @@ class StackConfigure(Subcommand):
)
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.core.package import ImageType
from llama_stack.distribution.build import ImageType
docker_image = None
build_config_file = Path(args.config)
@ -66,7 +66,7 @@ class StackConfigure(Subcommand):
os.makedirs(builds_dir, exist_ok=True)
script = pkg_resources.resource_filename(
"llama_toolchain", "core/configure_container.sh"
"llama_stack", "distribution/configure_container.sh"
)
script_args = [script, docker_image, str(builds_dir)]
@ -95,8 +95,8 @@ class StackConfigure(Subcommand):
build_config: BuildConfig,
output_dir: Optional[str] = None,
):
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.core.configure import configure_api_providers
from llama_stack.distribution.configure import configure_api_providers
from llama_stack.distribution.utils.serialize import EnumEncoder
builds_dir = BUILDS_BASE_DIR / build_config.image_type
if output_dir:
@ -105,16 +105,9 @@ class StackConfigure(Subcommand):
image_name = build_config.name.replace("::", "-")
run_config_file = builds_dir / f"{image_name}-run.yaml"
api2providers = build_config.distribution_spec.providers
stub_config = {
api_str: {"provider_id": provider}
for api_str, provider in api2providers.items()
}
if run_config_file.exists():
cprint(
f"Configuration already exists for {build_config.name}. Will overwrite...",
f"Configuration already exists at `{str(run_config_file)}`. Will overwrite...",
"yellow",
attrs=["bold"],
)
@ -123,10 +116,12 @@ class StackConfigure(Subcommand):
config = StackRunConfig(
built_at=datetime.now(),
image_name=image_name,
providers=stub_config,
apis_to_serve=[],
provider_map={},
)
config.providers = configure_api_providers(config.providers)
config = configure_api_providers(config, build_config.distribution_spec)
config.docker_image = (
image_name if build_config.image_type == "docker" else None
)

@ -6,7 +6,7 @@
import argparse
from llama_toolchain.cli.subcommand import Subcommand
from llama_stack.cli.subcommand import Subcommand
class StackListApis(Subcommand):
@ -25,8 +25,8 @@ class StackListApis(Subcommand):
pass
def _run_apis_list_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.cli.table import print_table
from llama_toolchain.core.distribution import stack_apis
from llama_stack.cli.table import print_table
from llama_stack.distribution.distribution import stack_apis
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [

@ -6,7 +6,7 @@
import argparse
from llama_toolchain.cli.subcommand import Subcommand
from llama_stack.cli.subcommand import Subcommand
class StackListProviders(Subcommand):
@ -22,7 +22,7 @@ class StackListProviders(Subcommand):
self.parser.set_defaults(func=self._run_providers_list_cmd)
def _add_arguments(self):
from llama_toolchain.core.distribution import stack_apis
from llama_stack.distribution.distribution import stack_apis
api_values = [a.value for a in stack_apis()]
self.parser.add_argument(
@ -33,8 +33,8 @@ class StackListProviders(Subcommand):
)
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.cli.table import print_table
from llama_toolchain.core.distribution import Api, api_providers
from llama_stack.cli.table import print_table
from llama_stack.distribution.distribution import Api, api_providers
all_providers = api_providers()
providers_for_api = all_providers[Api(args.api)]

@ -11,8 +11,8 @@ from pathlib import Path
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import * # noqa: F403
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import * # noqa: F403
class StackRun(Subcommand):
@ -47,7 +47,7 @@ class StackRun(Subcommand):
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_stack.distribution.utils.exec import run_with_pty
if not args.config:
self.parser.error("Must specify a config file to run")
@ -67,14 +67,14 @@ class StackRun(Subcommand):
if config.docker_image:
script = pkg_resources.resource_filename(
"llama_toolchain",
"core/start_container.sh",
"llama_stack",
"distribution/start_container.sh",
)
run_args = [script, config.docker_image]
else:
script = pkg_resources.resource_filename(
"llama_toolchain",
"core/start_conda_env.sh",
"llama_stack",
"distribution/start_conda_env.sh",
)
run_args = [
script,

@ -6,7 +6,7 @@
import argparse
from llama_toolchain.cli.subcommand import Subcommand
from llama_stack.cli.subcommand import Subcommand
from .build import StackBuild
from .configure import StackConfigure

@ -4,26 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
from datetime import datetime
from enum import Enum
from typing import List, Optional
import pkg_resources
import yaml
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.common.serialize import EnumEncoder
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.core.datatypes import * # noqa: F403
from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path
from llama_toolchain.core.distribution import api_providers, SERVER_DEPENDENCIES
from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
class ImageType(Enum):
@ -41,7 +35,7 @@ class ApiInput(BaseModel):
provider: str
def build_package(build_config: BuildConfig, build_file_path: Path):
def build_image(build_config: BuildConfig, build_file_path: Path):
package_deps = Dependencies(
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
pip_packages=SERVER_DEPENDENCIES,
@ -49,8 +43,19 @@ def build_package(build_config: BuildConfig, build_file_path: Path):
# extend package dependencies based on providers spec
all_providers = api_providers()
for api_str, provider in build_config.distribution_spec.providers.items():
for (
api_str,
provider_or_providers,
) in build_config.distribution_spec.providers.items():
providers_for_api = all_providers[Api(api_str)]
providers = (
provider_or_providers
if isinstance(provider_or_providers, list)
else [provider_or_providers]
)
for provider in providers:
if provider not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
@ -63,7 +68,7 @@ def build_package(build_config: BuildConfig, build_file_path: Path):
if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename(
"llama_toolchain", "core/build_container.sh"
"llama_stack", "distribution/build_container.sh"
)
args = [
script,
@ -74,7 +79,7 @@ def build_package(build_config: BuildConfig, build_file_path: Path):
]
else:
script = pkg_resources.resource_filename(
"llama_toolchain", "core/build_conda_env.sh"
"llama_stack", "distribution/build_conda_env.sh"
)
args = [
script,

@ -7,11 +7,11 @@
# the root directory of this source tree.
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
echo "Using llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR"
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
@ -78,19 +78,19 @@ ensure_conda_env_python310() {
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
pip install fastapi libcst
pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-toolchain==$TEST_PYPI_VERSION $pip_dependencies
pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION $pip_dependencies
else
# Re-installing llama-toolchain in the new conda environment
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
printf "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}\n" >&2
# Re-installing llama-stack in the new conda environment
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR\n"
pip install --no-cache-dir -e "$LLAMA_TOOLCHAIN_DIR"
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else
pip install --no-cache-dir llama-toolchain
pip install --no-cache-dir llama-stack
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then

@ -1,7 +1,7 @@
#!/bin/bash
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ "$#" -ne 4 ]; then
@ -55,17 +55,17 @@ RUN apt-get update && apt-get install -y \
EOF
toolchain_mount="/app/llama-toolchain-source"
stack_mount="/app/llama-stack-source"
models_mount="/app/llama-models-source"
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then
echo "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2
exit 1
fi
add_to_docker "RUN pip install $toolchain_mount"
add_to_docker "RUN pip install $stack_mount"
else
add_to_docker "RUN pip install llama-toolchain"
add_to_docker "RUN pip install llama-stack"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
@ -90,7 +90,7 @@ add_to_docker <<EOF
# This would be good in production but for debugging flexibility lets not add it right now
# We need a more solid production ready entrypoint.sh anyway
#
# ENTRYPOINT ["python", "-m", "llama_toolchain.core.server"]
# ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
EOF
@ -101,8 +101,8 @@ cat $TEMP_DIR/Dockerfile
printf "\n"
mounts=""
if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_TOOLCHAIN_DIR):$toolchain_mount"
if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"

View file

@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
from llama_stack.distribution.datatypes import * # noqa: F403
from termcolor import cprint
from llama_stack.distribution.distribution import api_providers, stack_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
# These are hacks so we can re-use the `prompt_for_config` utility
# This needs a bunch of work to be made very user friendly.
class ReqApis(BaseModel):
apis_to_serve: List[str]
def make_routing_entry_type(config_class: Any):
class BaseModelWithConfig(BaseModel):
routing_key: str
config: config_class
return BaseModelWithConfig
# TODO: make sure we can deal with existing configuration values correctly
# instead of just overwriting them
def configure_api_providers(
config: StackRunConfig, spec: DistributionSpec
) -> StackRunConfig:
cprint("Configuring APIs to serve...", "white", attrs=["bold"])
print("Enter comma-separated list of APIs to serve:")
apis = config.apis_to_serve or list(spec.providers.keys())
apis = [a for a in apis if a != "telemetry"]
req_apis = ReqApis(
apis_to_serve=apis,
)
req_apis = prompt_for_config(ReqApis, req_apis)
config.apis_to_serve = req_apis.apis_to_serve
print("")
apis = [v.value for v in stack_apis()]
all_providers = api_providers()
for api_str in spec.providers.keys():
if api_str not in apis:
raise ValueError(f"Unknown API `{api_str}`")
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
api = Api(api_str)
provider_or_providers = spec.providers[api_str]
if isinstance(provider_or_providers, list) and len(provider_or_providers) > 1:
print(
"You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n"
)
routing_entries = []
for p in provider_or_providers:
print(f"Configuring provider `{p}`...")
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class)
# TODO: we need to validate the routing keys, and
# perhaps it is better if we break this out into asking
# for a routing key separately from the associated config
wrapper_type = make_routing_entry_type(config_type)
rt_entry = prompt_for_config(wrapper_type, None)
routing_entries.append(
ProviderRoutingEntry(
provider_id=p,
routing_key=rt_entry.routing_key,
config=rt_entry.config.dict(),
)
)
config.provider_map[api_str] = routing_entries
else:
p = (
provider_or_providers[0]
if isinstance(provider_or_providers, list)
else provider_or_providers
)
print(f"Configuring provider `{p}`...")
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class)
try:
provider_config = config.provider_map.get(api_str)
if provider_config:
existing = config_type(**provider_config.config)
else:
existing = None
except Exception:
existing = None
cfg = prompt_for_config(config_type, existing)
config.provider_map[api_str] = GenericProviderConfig(
provider_id=p,
config=cfg.dict(),
)
return config

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .control_plane import * # noqa: F401 F403

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import RedisImplConfig
async def get_adapter_impl(config: RedisImplConfig, _deps):
from .redis import RedisControlPlaneAdapter
impl = RedisControlPlaneAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class RedisImplConfig(BaseModel):
url: str = Field(
description="The URL for the Redis server",
)
namespace: Optional[str] = Field(
default=None,
description="All keys will be prefixed with this namespace",
)

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime, timedelta
from typing import Any, List, Optional
from redis.asyncio import Redis
from llama_stack.apis.control_plane import * # noqa: F403
from .config import RedisImplConfig
class RedisControlPlaneAdapter(ControlPlane):
def __init__(self, config: RedisImplConfig):
self.config = config
async def initialize(self) -> None:
self.redis = Redis.from_url(self.config.url)
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
return f"{self.config.namespace}:{key}"
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
) -> None:
key = self._namespaced_key(key)
await self.redis.set(key, value)
if expiration:
await self.redis.expireat(key, expiration)
async def get(self, key: str) -> Optional[ControlPlaneValue]:
key = self._namespaced_key(key)
value = await self.redis.get(key)
if value is None:
return None
ttl = await self.redis.ttl(key)
expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None
return ControlPlaneValue(key=key, value=value, expiration=expiration)
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
keys = await self.redis.keys(f"{start_key}*")
result = []
for key in keys:
if key <= end_key:
value = await self.get(key)
if value:
result.append(value)
return result

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SqliteControlPlaneConfig
async def get_provider_impl(config: SqliteControlPlaneConfig, _deps):
from .control_plane import SqliteControlPlane
impl = SqliteControlPlane(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class SqliteControlPlaneConfig(BaseModel):
db_path: str = Field(
description="File path for the sqlite database",
)
table_name: str = Field(
default="llamastack_control_plane",
description="Table into which all the keys will be placed",
)

View file

@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from typing import Any, List, Optional
import aiosqlite
from llama_stack.apis.control_plane import * # noqa: F403
from .config import SqliteControlPlaneConfig
class SqliteControlPlane(ControlPlane):
def __init__(self, config: SqliteControlPlaneConfig):
self.db_path = config.db_path
self.table_name = config.table_name
async def initialize(self):
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
)
"""
)
await db.commit()
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
(key, json.dumps(value), expiration),
)
await db.commit()
async def get(self, key: str) -> Optional[ControlPlaneValue]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
) as cursor:
row = await cursor.fetchone()
if row is None:
return None
value, expiration = row
return ControlPlaneValue(
key=key, value=json.loads(value), expiration=expiration
)
async def delete(self, key: str) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit()
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
) as cursor:
result = []
async for row in cursor:
key, value, expiration = row
result.append(
ControlPlaneValue(
key=key, value=json.loads(value), expiration=expiration
)
)
return result
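As a quick illustration of the key-value semantics implemented above, here is a minimal sketch (not part of this diff) that drives the sqlite provider directly. It assumes the provider is importable at llama_stack.providers.impls.sqlite.control_plane, the module path the control plane registry uses later in this change; the database path and keys are hypothetical placeholders.
import asyncio
from llama_stack.providers.impls.sqlite.control_plane import (
    SqliteControlPlaneConfig,
    get_provider_impl,
)
async def demo() -> None:
    # Hypothetical db_path; table_name falls back to the default defined in the config class.
    config = SqliteControlPlaneConfig(db_path="/tmp/llamastack_demo.db")
    plane = await get_provider_impl(config, {})
    await plane.set("session:1", {"turns": 3})
    print(await plane.get("session:1"))
    print(await plane.range("session:0", "session:9"))
    await plane.delete("session:1")
asyncio.run(demo())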

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@json_schema_type
class ControlPlaneValue(BaseModel):
key: str
value: Any
expiration: Optional[datetime] = None
@json_schema_type
class ControlPlane(Protocol):
@webmethod(route="/control_plane/set")
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
) -> None: ...
@webmethod(route="/control_plane/get", method="GET")
async def get(self, key: str) -> Optional[ControlPlaneValue]: ...
@webmethod(route="/control_plane/delete")
async def delete(self, key: str) -> None: ...
@webmethod(route="/control_plane/range", method="GET")
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: ...

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.control_plane,
provider_id="sqlite",
pip_packages=["aiosqlite"],
module="llama_stack.providers.impls.sqlite.control_plane",
config_class="llama_stack.providers.impls.sqlite.control_plane.SqliteControlPlaneConfig",
),
remote_provider_spec(
Api.control_plane,
AdapterSpec(
adapter_id="redis",
pip_packages=["redis"],
module="llama_stack.providers.adapters.control_plane.redis",
),
),
]

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from llama_models.schema_utils import json_schema_type
@ -17,7 +17,7 @@ from pydantic import BaseModel, Field, validator
class Api(Enum):
inference = "inference"
safety = "safety"
agentic_system = "agentic_system"
agents = "agents"
memory = "memory"
telemetry = "telemetry"
@ -43,6 +43,33 @@ class ProviderSpec(BaseModel):
)
@json_schema_type
class RouterProviderSpec(ProviderSpec):
provider_id: str = "router"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(inner_impls, deps)`: returns the router implementation
""",
)
@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on RouterProviderSpec")
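To make the routing contract concrete, here is a minimal sketch (not part of this commit) of a router module that such a spec could point at. It follows the get_router_impl(inner_impls, deps) convention used when the server instantiates routers; the InferenceRouter class, its chat_completion signature, and the key-matching logic are illustrative assumptions, with the "model" field standing in for the routing key.
from typing import Any, Dict, List, Tuple
class InferenceRouter:
    """Illustrative router: dispatches each request to the provider registered
    for its `model` routing key."""
    def __init__(self, inner_impls: List[Tuple[str, Any]]):
        # inner_impls is a list of (routing_key, provider_impl) pairs.
        self.routes: Dict[str, Any] = dict(inner_impls)
    async def initialize(self) -> None:
        pass
    async def chat_completion(self, model: str, *args, **kwargs):
        impl = self.routes.get(model)
        if impl is None:
            raise ValueError(f"No provider configured for routing key `{model}`")
        return await impl.chat_completion(model, *args, **kwargs)
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: Dict[str, Any]):
    router = InferenceRouter(inner_impls)
    await router.initialize()
    return router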
class GenericProviderConfig(BaseModel):
provider_id: str
config: Dict[str, Any]
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
@ -124,7 +151,7 @@ as being "Llama Stack compatible"
def module(self) -> str:
if self.adapter:
return self.adapter.module
return f"llama_toolchain.{self.api.value}.client"
return f"llama_stack.apis.{self.api.value}.client"
@property
def pip_packages(self) -> List[str]:
@ -140,7 +167,7 @@ def remote_provider_spec(
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_toolchain.core.datatypes.RemoteProviderConfig"
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
@ -156,12 +183,23 @@ class DistributionSpec(BaseModel):
description="Description of the distribution",
)
docker_image: Optional[str] = None
providers: Dict[str, str] = Field(
providers: Dict[str, Union[str, List[str]]] = Field(
default_factory=dict,
description="Provider Types for each of the APIs provided by this distribution",
description="""
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate routing table (via the
'provider_map' entries) in the runtime configuration to help route to the correct provider.""",
)
@json_schema_type
class ProviderRoutingEntry(GenericProviderConfig):
routing_key: str
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
@json_schema_type
class StackRunConfig(BaseModel):
built_at: datetime
@ -181,12 +219,22 @@ this could be just a hash
default=None,
description="Reference to the conda environment if this package refers to a conda environment",
)
providers: Dict[str, Any] = Field(
default_factory=dict,
apis_to_serve: List[str] = Field(
description="""
Provider configurations for each of the APIs provided by this package. This includes configurations for
the dependencies of these providers as well.
""",
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
provider_map: Dict[str, ProviderMapEntry] = Field(
description="""
Provider configurations for each of the APIs provided by this package.
Given an API, you can specify a single provider or a "routing table". Each entry in the routing
table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
As examples:
- the "inference" API interprets the routing_key as a "model"
- the "memory" API interprets the routing_key as the type of a "memory bank"
The key may also support wild-cards; the routing_key is used to route to the correct provider.""",
)
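For orientation, a minimal sketch (not part of this commit) of a provider_map holding both a plain entry and a routing table. The provider ids mirror the example distribution specs elsewhere in this change; the routing keys and the empty config dicts are hypothetical placeholders, not working configurations.
from llama_stack.distribution.datatypes import (
    GenericProviderConfig,
    ProviderRoutingEntry,
)
provider_map = {
    # Single provider serving the safety API.
    "safety": GenericProviderConfig(
        provider_id="meta-reference",
        config={},  # provider-specific config omitted
    ),
    # Routing table for the inference API, keyed by model name (hypothetical keys).
    "inference": [
        ProviderRoutingEntry(
            provider_id="meta-reference",
            routing_key="Meta-Llama3.1-8B-Instruct",
            config={},
        ),
        ProviderRoutingEntry(
            provider_id="remote::ollama",
            routing_key="Meta-Llama3.1-70B-Instruct",
            config={},
        ),
    ],
}
An apis_to_serve list such as ["inference", "safety"] would then select which of these entries the server actually exposes.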

View file

@ -8,18 +8,19 @@ import importlib
import inspect
from typing import Dict, List
from llama_toolchain.agentic_system.api import AgenticSystem
from llama_toolchain.inference.api import Inference
from llama_toolchain.memory.api import Memory
from llama_toolchain.safety.api import Safety
from llama_toolchain.telemetry.api import Telemetry
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.telemetry import Telemetry
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
# These are the dependencies needed by the distribution server.
# `llama-toolchain` is automatically installed by the installation script.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"fastapi",
"fire",
"uvicorn",
]
@ -34,7 +35,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agentic_system: AgenticSystem,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
}
@ -67,7 +68,7 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
for api in stack_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_toolchain.{name}.providers")
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {
"remote": remote_provider_spec(api),
**{a.provider_id: a for a in module.available_providers()},

View file

@ -0,0 +1,10 @@
name: local-conda-example
distribution_spec:
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -3,8 +3,8 @@ distribution_spec:
description: Use Fireworks.ai for running LLM inference
providers:
inference: remote::fireworks
memory: meta-reference-faiss
memory: meta-reference
safety: meta-reference
agentic_system: meta-reference
telemetry: console
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -3,8 +3,8 @@ distribution_spec:
description: Like local, but use ollama for running LLM inference
providers:
inference: remote::ollama
memory: meta-reference-faiss
memory: meta-reference
safety: meta-reference
agentic_system: meta-reference
telemetry: console
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -3,8 +3,8 @@ distribution_spec:
description: Use TGI (locally or with Hugging Face Inference Endpoints) for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint.
providers:
inference: remote::tgi
memory: meta-reference-faiss
memory: meta-reference
safety: meta-reference
agentic_system: meta-reference
telemetry: console
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -3,8 +3,8 @@ distribution_spec:
description: Use Together.ai for running LLM inference
providers:
inference: remote::together
memory: meta-reference-faiss
memory: meta-reference
safety: meta-reference
agentic_system: meta-reference
telemetry: console
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -0,0 +1,10 @@
name: local-docker-example
distribution_spec:
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: docker

View file

@ -9,6 +9,7 @@ import inspect
import json
import signal
import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
@ -38,16 +39,16 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_toolchain.telemetry.tracing import (
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
SpanStatus,
start_trace,
)
from llama_stack.distribution.datatypes import * # noqa: F403
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.utils.dynamic import instantiate_provider
def is_async_iterator_type(typ):
@ -271,61 +272,80 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack]
def resolve_impls(
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(provider_specs.values())
def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))
impls = {}
for provider_spec in provider_specs:
api = provider_spec.api
if api.value not in provider_configs:
async def resolve_impls(
provider_map: Dict[str, ProviderMapEntry],
) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = api_providers()
specs = {}
for api_str, item in provider_map.items():
api = Api(api_str)
providers = all_providers[api]
if isinstance(item, GenericProviderConfig):
if item.provider_id not in providers:
raise ValueError(
f"Could not find provider_spec config for {api}. Please add it to the config"
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[item.provider_id]
else:
assert isinstance(item, list)
inner_specs = []
for rt_entry in item:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.providers.routers.{api.value.lower()}",
api_dependencies=[],
inner_specs=inner_specs,
)
if isinstance(provider_spec, InlineProviderSpec):
deps = {api: impls[api] for api in provider_spec.api_dependencies}
else:
deps = {}
provider_config = provider_configs[api.value]
impl = instantiate_provider(provider_spec, provider_config, deps)
sorted_specs = topological_sort(specs.values())
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, provider_map[api.value])
impls[api] = impl
return impls
return impls, specs
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
config = StackRunConfig(**yaml.safe_load(fp))
app = FastAPI()
all_endpoints = api_endpoints()
all_providers = api_providers()
provider_specs = {}
for api_str, provider_config in config["providers"].items():
api = Api(api_str)
providers = all_providers[api]
provider_id = provider_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_specs[api] = providers[provider_id]
impls = resolve_impls(provider_specs, config)
impls, specs = asyncio.run(resolve_impls(config.provider_map))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
for provider_spec in provider_specs.values():
api = provider_spec.api
all_endpoints = api_endpoints()
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
impl = impls[api]
provider_spec = specs[api]
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None

View file

@ -37,6 +37,6 @@ eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
$CONDA_PREFIX/bin/python \
-m llama_toolchain.core.server \
-m llama_stack.distribution.server.server \
--yaml_config "$yaml_config" \
--port "$port" "$@"

View file

@ -38,6 +38,6 @@ podman run -it \
-p $port:$port \
-v "$yaml_config:/app/config.yaml" \
$docker_image \
python -m llama_toolchain.core.server \
python -m llama_stack.distribution.server.server \
--yaml_config /app/config.yaml \
--port $port "$@"

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
from typing import Any, Dict
from llama_stack.distribution.datatypes import * # noqa: F403
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
# instantiates and returns an implementation of the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
deps: Dict[str, Any],
provider_config: ProviderMapEntry,
):
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, RouterProviderSpec):
method = "get_router_impl"
assert isinstance(provider_config, list)
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in provider_config:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_id],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [inner_impls, deps]
else:
method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl

View file

@ -27,6 +27,12 @@ def is_list_of_primitives(field_type):
return False
def is_basemodel_without_fields(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
)
def can_recurse(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
@ -151,6 +157,11 @@ def prompt_for_config(
if get_origin(field_type) is Literal:
continue
# Skip fields with no type annotations
if is_basemodel_without_fields(field_type):
config_data[field_name] = field_type()
continue
if inspect.isclass(field_type) and issubclass(field_type, Enum):
prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):"
while True:
@ -254,6 +265,20 @@ def prompt_for_config(
print(f"{str(e)}")
continue
elif get_origin(field_type) is dict:
try:
value = json.loads(user_input)
if not isinstance(value, dict):
raise ValueError(
"Input must be a JSON-encoded dictionary"
)
except json.JSONDecodeError:
print(
"Invalid JSON. Please enter a valid JSON-encoded dict."
)
continue
# Convert the input to the correct type
elif inspect.isclass(field_type) and issubclass(
field_type, BaseModel

Some files were not shown because too many files have changed in this diff.