From 9487ad8294fd1ca989b091f08d8b9b44d0629e75 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 17 Sep 2024 19:51:35 -0700 Subject: [PATCH] API Updates (#73) * API Keys passed from Client instead of distro configuration * delete distribution registry * Rename the "package" word away * Introduce a "Router" layer for providers Some providers need to be factorized and considered as thin routing layers on top of other providers. Consider two examples: - The inference API should be a routing layer over inference providers, routed using the "model" key - The memory banks API is another instance where various memory bank types will be provided by independent providers (e.g., a vector store is served by Chroma while a keyvalue memory can be served by Redis or PGVector) This commit introduces a generalized routing layer for this purpose. * update `apis_to_serve` * llama_toolchain -> llama_stack * Codemod from llama_toolchain -> llama_stack - added providers/registry - cleaned up api/ subdirectories and moved impls away - restructured api/api.py - from llama_stack.apis. import foo should work now - update imports to do llama_stack.apis. - update many other imports - added __init__, fixed some registry imports - updated registry imports - create_agentic_system -> create_agent - AgenticSystem -> Agent * Moved some stuff out of common/; re-generated OpenAPI spec * llama-toolchain -> llama-stack (hyphens) * add control plane API * add redis adapter + sqlite provider * move core -> distribution * Some more toolchain -> stack changes * small naming shenanigans * Removing custom tool and agent utilities and moving them client side * Move control plane to distribution server for now * Remove control plane from API list * no codeshield dependency randomly plzzzzz * Add "fire" as a dependency * add back event loggers * stack configure fixes * use brave instead of bing in the example client * add init file so it gets packaged * add init files so it gets packaged * Update MANIFEST * bug fix --------- Co-authored-by: Hardik Shah Co-authored-by: Xi Yan Co-authored-by: Ashwin Bharambe --- MANIFEST.in | 7 +- README.md | 4 +- docs/cli_reference.md | 24 +- docs/getting_started.md | 26 +- {llama_toolchain => llama_stack}/__init__.py | 0 .../apis}/__init__.py | 0 llama_stack/apis/agents/__init__.py | 7 + .../apis/agents/agents.py | 152 ++++++------ .../apis/agents}/client.py | 59 ++--- .../apis/agents}/event_logger.py | 9 +- llama_stack/apis/batch_inference/__init__.py | 7 + .../apis/batch_inference/batch_inference.py | 2 +- .../apis/common}/__init__.py | 0 .../apis}/common/deployment_types.py | 0 .../apis}/common/training_types.py | 0 llama_stack/apis/dataset/__init__.py | 7 + .../apis/dataset/dataset.py | 0 llama_stack/apis/evals/__init__.py | 7 + .../api.py => llama_stack/apis/evals/evals.py | 4 +- llama_stack/apis/inference/__init__.py | 7 + .../apis}/inference/client.py | 7 +- .../apis}/inference/event_logger.py | 2 +- .../apis/inference/inference.py | 0 llama_stack/apis/memory/__init__.py | 7 + .../apis}/memory/client.py | 7 +- .../apis/memory/memory.py | 0 llama_stack/apis/models/__init__.py | 7 + .../apis/models/models.py | 0 llama_stack/apis/post_training/__init__.py | 7 + .../apis/post_training/post_training.py | 4 +- llama_stack/apis/reward_scoring/__init__.py | 7 + .../apis/reward_scoring/reward_scoring.py | 0 llama_stack/apis/safety/__init__.py | 7 + .../apis}/safety/client.py | 4 +- .../apis/safety/safety.py | 2 +- llama_stack/apis/stack.py | 34 +++ .../synthetic_data_generation/__init__.py | 7 + .../synthetic_data_generation.py | 2 +- llama_stack/apis/telemetry/__init__.py | 7 + .../apis/telemetry/telemetry.py | 0 .../scripts => llama_stack/cli}/__init__.py | 0 .../cli/download.py | 10 +- {llama_toolchain => llama_stack}/cli/llama.py | 10 - .../cli/model/__init__.py | 0 .../cli/model/describe.py | 8 +- .../cli/model/download.py | 4 +- .../cli/model/list.py | 4 +- .../cli/model/model.py | 10 +- .../cli/model/template.py | 4 +- .../cli/scripts}/__init__.py | 0 .../scripts/install-wheel-from-presigned.sh | 0 .../cli/scripts/run.py | 0 .../cli/stack/__init__.py | 0 .../cli/stack/build.py | 24 +- .../cli/stack/configure.py | 31 +-- .../cli/stack/list_apis.py | 6 +- .../cli/stack/list_providers.py | 8 +- .../cli/stack/run.py | 14 +- .../cli/stack/stack.py | 2 +- .../cli/subcommand.py | 0 {llama_toolchain => llama_stack}/cli/table.py | 0 .../distribution}/__init__.py | 0 .../distribution/build.py | 49 ++-- .../distribution}/build_conda_env.sh | 22 +- .../distribution}/build_container.sh | 20 +- .../distribution}/common.sh | 0 llama_stack/distribution/configure.py | 110 +++++++++ .../distribution}/configure_container.sh | 0 .../distribution/control_plane/__init__.py | 7 + .../control_plane/adapters}/__init__.py | 0 .../control_plane/adapters/redis/__init__.py | 15 ++ .../control_plane/adapters/redis/config.py | 19 ++ .../control_plane/adapters/redis/redis.py | 62 +++++ .../control_plane/adapters/sqlite/__init__.py | 15 ++ .../control_plane/adapters/sqlite/config.py | 19 ++ .../adapters/sqlite/control_plane.py | 79 ++++++ llama_stack/distribution/control_plane/api.py | 35 +++ .../distribution/control_plane/registry.py | 29 +++ .../distribution}/datatypes.py | 70 +++++- .../distribution}/distribution.py | 17 +- .../conda/local-conda-example-build.yaml | 10 + .../local-fireworks-conda-example-build.yaml | 6 +- .../local-ollama-conda-example-build.yaml | 6 +- .../conda/local-tgi-conda-example-build.yaml | 6 +- .../local-together-conda-example-build.yaml | 6 +- .../docker/local-docker-example-build.yaml | 10 + .../distribution/server}/__init__.py | 0 .../distribution/server}/server.py | 102 ++++---- .../distribution}/start_conda_env.sh | 2 +- .../distribution}/start_container.sh | 2 +- .../distribution/utils}/__init__.py | 0 .../distribution/utils}/config_dirs.py | 0 llama_stack/distribution/utils/dynamic.py | 66 +++++ .../distribution/utils}/exec.py | 0 .../distribution/utils}/model_utils.py | 0 .../distribution/utils}/prompt_for_config.py | 25 ++ .../distribution/utils}/serialize.py | 0 .../providers}/__init__.py | 0 .../providers/adapters}/__init__.py | 0 .../providers/adapters/inference}/__init__.py | 0 .../adapters/inference}/fireworks/__init__.py | 0 .../adapters/inference}/fireworks/config.py | 0 .../inference}/fireworks/fireworks.py | 11 +- .../adapters/inference}/ollama/__init__.py | 2 +- .../adapters/inference}/ollama/ollama.py | 9 +- .../adapters/inference}/tgi/__init__.py | 0 .../adapters/inference}/tgi/config.py | 0 .../providers/adapters/inference}/tgi/tgi.py | 8 +- .../adapters/inference}/together/__init__.py | 0 .../adapters/inference}/together/config.py | 0 .../adapters/inference}/together/together.py | 9 +- .../providers/adapters/memory}/__init__.py | 0 .../adapters/memory}/chroma/__init__.py | 2 +- .../adapters/memory}/chroma/chroma.py | 7 +- .../adapters/memory}/pgvector/__init__.py | 0 .../adapters/memory}/pgvector/config.py | 0 .../adapters/memory}/pgvector/pgvector.py | 4 +- .../providers/impls}/__init__.py | 0 .../impls/meta_reference}/__init__.py | 0 .../impls/meta_reference/agents}/__init__.py | 6 +- .../meta_reference/agents}/agent_instance.py | 126 +++++----- .../impls/meta_reference/agents/agents.py | 50 ++-- .../impls/meta_reference/agents/config.py | 10 + .../meta_reference/agents/rag}/__init__.py | 0 .../agents}/rag/context_retriever.py | 4 +- .../impls/meta_reference/agents}/safety.py | 4 +- .../meta_reference/agents/tests}/__init__.py | 0 .../agents/tests/code_execution.py | 93 ++++++++ .../meta_reference/agents/tools}/__init__.py | 2 - .../meta_reference/agents}/tools/base.py | 2 +- .../meta_reference/agents}/tools/builtin.py | 4 +- .../agents/tools/ipython_tool}/__init__.py | 2 - .../tools/ipython_tool/code_env_prefix.py | 0 .../tools/ipython_tool/code_execution.py | 0 .../ipython_tool/matplotlib_custom_backend.py | 0 .../agents}/tools/ipython_tool/utils.py | 0 .../meta_reference/agents}/tools/safety.py | 6 +- .../meta_reference/inference}/__init__.py | 0 .../impls/meta_reference/inference}/config.py | 4 +- .../meta_reference/inference}/generation.py | 4 +- .../meta_reference/inference}/inference.py | 10 +- .../inference}/model_parallel.py | 0 .../inference}/parallel_utils.py | 0 .../inference/quantization}/__init__.py | 2 - .../inference/quantization/fp8_impls.py | 0 .../inference/quantization/loader.py | 4 +- .../quantization/scripts}/__init__.py | 2 - .../quantization/scripts/build_conda.sh | 0 .../scripts/quantize_checkpoint.py | 0 .../scripts/run_quantize_checkpoint.sh | 0 .../inference/quantization/test_fp8.py | 0 .../impls/meta_reference/memory}/__init__.py | 0 .../impls/meta_reference/memory}/config.py | 0 .../impls/meta_reference/memory}/faiss.py | 7 +- .../impls/meta_reference/safety}/__init__.py | 0 .../impls/meta_reference/safety}/config.py | 0 .../impls/meta_reference/safety}/safety.py | 4 +- .../safety}/shields/__init__.py | 0 .../meta_reference/safety}/shields/base.py | 2 +- .../safety}/shields/code_scanner.py | 2 +- .../safety/shields/contrib/__init__.py | 5 + .../shields/contrib/third_party_shield.py | 2 +- .../safety}/shields/llama_guard.py | 2 +- .../safety}/shields/prompt_guard.py | 2 +- .../meta_reference/telemetry}/__init__.py | 0 .../impls/meta_reference/telemetry}/config.py | 0 .../meta_reference/telemetry}/console.py | 2 +- .../providers/impls/sqlite/__init__.py | 5 + llama_stack/providers/registry/__init__.py | 5 + .../providers/registry/agents.py | 9 +- .../providers/registry/inference.py | 21 +- .../providers/registry/memory.py | 22 +- .../providers/registry/safety.py | 6 +- .../providers/registry/telemetry.py | 8 +- llama_stack/providers/routers/__init__.py | 5 + .../providers/routers/memory/__init__.py | 17 ++ .../providers/routers/memory/memory.py | 91 +++++++ llama_stack/providers/utils/__init__.py | 5 + .../providers/utils/inference/__init__.py | 5 + .../utils}/inference/prepare_messages.py | 2 +- .../providers/utils/memory/__init__.py | 5 + .../providers/utils/memory}/file_utils.py | 0 .../providers/utils/memory}/vector_store.py | 2 +- .../providers/utils/telemetry/__init__.py | 5 + .../providers/utils}/telemetry/tracing.py | 2 +- llama_toolchain/agentic_system/__init__.py | 0 .../execute_with_custom_tools.py | 96 -------- .../agentic_system/meta_reference/config.py | 15 -- .../conda/local-conda-example-build.yaml | 10 - .../docker/local-docker-example-build.yaml | 10 - llama_toolchain/core/configure.py | 50 ---- llama_toolchain/core/dynamic.py | 42 ---- llama_toolchain/inference/api/__init__.py | 7 - llama_toolchain/memory/api/__init__.py | 7 - llama_toolchain/post_training/api/__init__.py | 7 - .../reward_scoring/api/__init__.py | 7 - llama_toolchain/safety/api/__init__.py | 7 - llama_toolchain/stack.py | 34 --- .../synthetic_data_generation/api/__init__.py | 7 - llama_toolchain/telemetry/api/__init__.py | 7 - llama_toolchain/tools/custom/datatypes.py | 98 -------- requirements.txt | 1 + .../llama-stack-spec.html | 225 +++++++++--------- .../llama-stack-spec.yaml | 214 +++++++++-------- rfcs/RFC-0001-llama-stack.md | 20 +- rfcs/openapi_generator/README.md | 2 +- rfcs/openapi_generator/generate.py | 2 +- setup.py | 10 +- tests/example_custom_tool.py | 2 +- tests/test_e2e.py | 8 +- tests/test_inference.py | 12 +- tests/test_ollama_inference.py | 12 +- tests/test_prepare_messages.py | 10 +- 213 files changed, 1725 insertions(+), 1204 deletions(-) rename {llama_toolchain => llama_stack}/__init__.py (100%) rename {llama_toolchain/batch_inference => llama_stack/apis}/__init__.py (100%) create mode 100644 llama_stack/apis/agents/__init__.py rename llama_toolchain/agentic_system/api/api.py => llama_stack/apis/agents/agents.py (68%) rename {llama_toolchain/agentic_system => llama_stack/apis/agents}/client.py (79%) rename {llama_toolchain/agentic_system => llama_stack/apis/agents}/event_logger.py (97%) create mode 100644 llama_stack/apis/batch_inference/__init__.py rename llama_toolchain/batch_inference/api/api.py => llama_stack/apis/batch_inference/batch_inference.py (97%) rename {llama_toolchain/cli => llama_stack/apis/common}/__init__.py (100%) rename {llama_toolchain => llama_stack/apis}/common/deployment_types.py (100%) rename {llama_toolchain => llama_stack/apis}/common/training_types.py (100%) create mode 100644 llama_stack/apis/dataset/__init__.py rename llama_toolchain/dataset/api/api.py => llama_stack/apis/dataset/dataset.py (100%) create mode 100644 llama_stack/apis/evals/__init__.py rename llama_toolchain/evaluations/api/api.py => llama_stack/apis/evals/evals.py (95%) create mode 100644 llama_stack/apis/inference/__init__.py rename {llama_toolchain => llama_stack/apis}/inference/client.py (97%) rename {llama_toolchain => llama_stack/apis}/inference/event_logger.py (97%) rename llama_toolchain/inference/api/api.py => llama_stack/apis/inference/inference.py (100%) create mode 100644 llama_stack/apis/memory/__init__.py rename {llama_toolchain => llama_stack/apis}/memory/client.py (97%) rename llama_toolchain/memory/api/api.py => llama_stack/apis/memory/memory.py (100%) create mode 100644 llama_stack/apis/models/__init__.py rename llama_toolchain/models/api/api.py => llama_stack/apis/models/models.py (100%) create mode 100644 llama_stack/apis/post_training/__init__.py rename llama_toolchain/post_training/api/api.py => llama_stack/apis/post_training/post_training.py (97%) create mode 100644 llama_stack/apis/reward_scoring/__init__.py rename llama_toolchain/reward_scoring/api/api.py => llama_stack/apis/reward_scoring/reward_scoring.py (100%) create mode 100644 llama_stack/apis/safety/__init__.py rename {llama_toolchain => llama_stack/apis}/safety/client.py (95%) rename llama_toolchain/safety/api/api.py => llama_stack/apis/safety/safety.py (96%) create mode 100644 llama_stack/apis/stack.py create mode 100644 llama_stack/apis/synthetic_data_generation/__init__.py rename llama_toolchain/synthetic_data_generation/api/api.py => llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py (96%) create mode 100644 llama_stack/apis/telemetry/__init__.py rename llama_toolchain/telemetry/api/api.py => llama_stack/apis/telemetry/telemetry.py (100%) rename {llama_toolchain/cli/scripts => llama_stack/cli}/__init__.py (100%) rename {llama_toolchain => llama_stack}/cli/download.py (97%) rename {llama_toolchain => llama_stack}/cli/llama.py (80%) rename {llama_toolchain => llama_stack}/cli/model/__init__.py (100%) rename {llama_toolchain => llama_stack}/cli/model/describe.py (93%) rename {llama_toolchain => llama_stack}/cli/model/download.py (83%) rename {llama_toolchain => llama_stack}/cli/model/list.py (94%) rename {llama_toolchain => llama_stack}/cli/model/model.py (73%) rename {llama_toolchain => llama_stack}/cli/model/template.py (97%) rename {llama_toolchain/common => llama_stack/cli/scripts}/__init__.py (100%) rename {llama_toolchain => llama_stack}/cli/scripts/install-wheel-from-presigned.sh (100%) rename {llama_toolchain => llama_stack}/cli/scripts/run.py (100%) rename {llama_toolchain => llama_stack}/cli/stack/__init__.py (100%) rename {llama_toolchain => llama_stack}/cli/stack/build.py (76%) rename {llama_toolchain => llama_stack}/cli/stack/configure.py (82%) rename {llama_toolchain => llama_stack}/cli/stack/list_apis.py (87%) rename {llama_toolchain => llama_stack}/cli/stack/list_providers.py (87%) rename {llama_toolchain => llama_stack}/cli/stack/run.py (87%) rename {llama_toolchain => llama_stack}/cli/stack/stack.py (94%) rename {llama_toolchain => llama_stack}/cli/subcommand.py (100%) rename {llama_toolchain => llama_stack}/cli/table.py (100%) rename {llama_toolchain/core => llama_stack/distribution}/__init__.py (100%) rename llama_toolchain/core/package.py => llama_stack/distribution/build.py (58%) rename {llama_toolchain/core => llama_stack/distribution}/build_conda_env.sh (81%) rename {llama_toolchain/core => llama_stack/distribution}/build_container.sh (80%) rename {llama_toolchain/core => llama_stack/distribution}/common.sh (100%) create mode 100644 llama_stack/distribution/configure.py rename {llama_toolchain/core => llama_stack/distribution}/configure_container.sh (100%) create mode 100644 llama_stack/distribution/control_plane/__init__.py rename {llama_toolchain/inference => llama_stack/distribution/control_plane/adapters}/__init__.py (100%) create mode 100644 llama_stack/distribution/control_plane/adapters/redis/__init__.py create mode 100644 llama_stack/distribution/control_plane/adapters/redis/config.py create mode 100644 llama_stack/distribution/control_plane/adapters/redis/redis.py create mode 100644 llama_stack/distribution/control_plane/adapters/sqlite/__init__.py create mode 100644 llama_stack/distribution/control_plane/adapters/sqlite/config.py create mode 100644 llama_stack/distribution/control_plane/adapters/sqlite/control_plane.py create mode 100644 llama_stack/distribution/control_plane/api.py create mode 100644 llama_stack/distribution/control_plane/registry.py rename {llama_toolchain/core => llama_stack/distribution}/datatypes.py (72%) rename {llama_toolchain/core => llama_stack/distribution}/distribution.py (79%) create mode 100644 llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml rename {llama_toolchain/configs/distributions => llama_stack/distribution/example_configs}/conda/local-fireworks-conda-example-build.yaml (69%) rename {llama_toolchain/configs/distributions => llama_stack/distribution/example_configs}/conda/local-ollama-conda-example-build.yaml (69%) rename {llama_toolchain/configs/distributions => llama_stack/distribution/example_configs}/conda/local-tgi-conda-example-build.yaml (77%) rename {llama_toolchain/configs/distributions => llama_stack/distribution/example_configs}/conda/local-together-conda-example-build.yaml (68%) create mode 100644 llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml rename {llama_toolchain/inference/adapters => llama_stack/distribution/server}/__init__.py (100%) rename {llama_toolchain/core => llama_stack/distribution/server}/server.py (82%) rename {llama_toolchain/core => llama_stack/distribution}/start_conda_env.sh (94%) rename {llama_toolchain/core => llama_stack/distribution}/start_container.sh (93%) rename {llama_toolchain/memory => llama_stack/distribution/utils}/__init__.py (100%) rename {llama_toolchain/common => llama_stack/distribution/utils}/config_dirs.py (100%) create mode 100644 llama_stack/distribution/utils/dynamic.py rename {llama_toolchain/common => llama_stack/distribution/utils}/exec.py (100%) rename {llama_toolchain/common => llama_stack/distribution/utils}/model_utils.py (100%) rename {llama_toolchain/common => llama_stack/distribution/utils}/prompt_for_config.py (91%) rename {llama_toolchain/common => llama_stack/distribution/utils}/serialize.py (100%) rename {llama_toolchain/memory/common => llama_stack/providers}/__init__.py (100%) rename {llama_toolchain/memory/meta_reference => llama_stack/providers/adapters}/__init__.py (100%) rename {llama_toolchain/safety => llama_stack/providers/adapters/inference}/__init__.py (100%) rename {llama_toolchain/inference/adapters => llama_stack/providers/adapters/inference}/fireworks/__init__.py (100%) rename {llama_toolchain/inference/adapters => llama_stack/providers/adapters/inference}/fireworks/config.py (100%) rename {llama_toolchain/inference/adapters => llama_stack/providers/adapters/inference}/fireworks/fireworks.py (97%) rename {llama_toolchain/inference/adapters => llama_stack/providers/adapters/inference}/ollama/__init__.py (85%) rename {llama_toolchain/inference/adapters => llama_stack/providers/adapters/inference}/ollama/ollama.py (97%) rename {llama_toolchain/inference/adapters => llama_stack/providers/adapters/inference}/tgi/__init__.py (100%) rename {llama_toolchain/inference/adapters => llama_stack/providers/adapters/inference}/tgi/config.py (100%) rename {llama_toolchain/inference/adapters => llama_stack/providers/adapters/inference}/tgi/tgi.py (97%) rename {llama_toolchain/inference/adapters => llama_stack/providers/adapters/inference}/together/__init__.py (100%) rename {llama_toolchain/inference/adapters => llama_stack/providers/adapters/inference}/together/config.py (100%) rename {llama_toolchain/inference/adapters => llama_stack/providers/adapters/inference}/together/together.py (97%) rename {llama_toolchain/safety/meta_reference/shields/contrib => llama_stack/providers/adapters/memory}/__init__.py (100%) rename {llama_toolchain/memory/adapters => llama_stack/providers/adapters/memory}/chroma/__init__.py (85%) rename {llama_toolchain/memory/adapters => llama_stack/providers/adapters/memory}/chroma/chroma.py (97%) rename {llama_toolchain/memory/adapters => llama_stack/providers/adapters/memory}/pgvector/__init__.py (100%) rename {llama_toolchain/memory/adapters => llama_stack/providers/adapters/memory}/pgvector/config.py (100%) rename {llama_toolchain/memory/adapters => llama_stack/providers/adapters/memory}/pgvector/pgvector.py (98%) rename {llama_toolchain/telemetry => llama_stack/providers/impls}/__init__.py (100%) rename {llama_toolchain/tools => llama_stack/providers/impls/meta_reference}/__init__.py (100%) rename {llama_toolchain/agentic_system/meta_reference => llama_stack/providers/impls/meta_reference/agents}/__init__.py (79%) rename {llama_toolchain/agentic_system/meta_reference => llama_stack/providers/impls/meta_reference/agents}/agent_instance.py (87%) rename llama_toolchain/agentic_system/meta_reference/agentic_system.py => llama_stack/providers/impls/meta_reference/agents/agents.py (71%) create mode 100644 llama_stack/providers/impls/meta_reference/agents/config.py rename {llama_toolchain/tools/custom => llama_stack/providers/impls/meta_reference/agents/rag}/__init__.py (100%) rename {llama_toolchain/agentic_system/meta_reference => llama_stack/providers/impls/meta_reference/agents}/rag/context_retriever.py (95%) rename {llama_toolchain/agentic_system/meta_reference => llama_stack/providers/impls/meta_reference/agents}/safety.py (98%) rename {llama_toolchain/tools/ipython_tool => llama_stack/providers/impls/meta_reference/agents/tests}/__init__.py (100%) create mode 100644 llama_stack/providers/impls/meta_reference/agents/tests/code_execution.py rename {llama_toolchain/dataset/api => llama_stack/providers/impls/meta_reference/agents/tools}/__init__.py (83%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference/agents}/tools/base.py (90%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference/agents}/tools/builtin.py (99%) rename {llama_toolchain/agentic_system/api => llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool}/__init__.py (83%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference/agents}/tools/ipython_tool/code_env_prefix.py (100%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference/agents}/tools/ipython_tool/code_execution.py (100%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference/agents}/tools/ipython_tool/matplotlib_custom_backend.py (100%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference/agents}/tools/ipython_tool/utils.py (100%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference/agents}/tools/safety.py (88%) rename {llama_toolchain/inference/meta_reference => llama_stack/providers/impls/meta_reference/inference}/__init__.py (100%) rename {llama_toolchain/inference/meta_reference => llama_stack/providers/impls/meta_reference/inference}/config.py (96%) rename {llama_toolchain/inference/meta_reference => llama_stack/providers/impls/meta_reference/inference}/generation.py (98%) rename {llama_toolchain/inference/meta_reference => llama_stack/providers/impls/meta_reference/inference}/inference.py (96%) rename {llama_toolchain/inference/meta_reference => llama_stack/providers/impls/meta_reference/inference}/model_parallel.py (100%) rename {llama_toolchain/inference/meta_reference => llama_stack/providers/impls/meta_reference/inference}/parallel_utils.py (100%) rename {llama_toolchain/evaluations/api => llama_stack/providers/impls/meta_reference/inference/quantization}/__init__.py (83%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference}/inference/quantization/fp8_impls.py (100%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference}/inference/quantization/loader.py (97%) rename {llama_toolchain/batch_inference/api => llama_stack/providers/impls/meta_reference/inference/quantization/scripts}/__init__.py (83%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference}/inference/quantization/scripts/build_conda.sh (100%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference}/inference/quantization/scripts/quantize_checkpoint.py (100%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference}/inference/quantization/scripts/run_quantize_checkpoint.sh (100%) rename {llama_toolchain => llama_stack/providers/impls/meta_reference}/inference/quantization/test_fp8.py (100%) rename {llama_toolchain/memory/meta_reference/faiss => llama_stack/providers/impls/meta_reference/memory}/__init__.py (100%) rename {llama_toolchain/memory/meta_reference/faiss => llama_stack/providers/impls/meta_reference/memory}/config.py (100%) rename {llama_toolchain/memory/meta_reference/faiss => llama_stack/providers/impls/meta_reference/memory}/faiss.py (95%) rename {llama_toolchain/safety/meta_reference => llama_stack/providers/impls/meta_reference/safety}/__init__.py (100%) rename {llama_toolchain/safety/meta_reference => llama_stack/providers/impls/meta_reference/safety}/config.py (100%) rename {llama_toolchain/safety/meta_reference => llama_stack/providers/impls/meta_reference/safety}/safety.py (96%) rename {llama_toolchain/safety/meta_reference => llama_stack/providers/impls/meta_reference/safety}/shields/__init__.py (100%) rename {llama_toolchain/safety/meta_reference => llama_stack/providers/impls/meta_reference/safety}/shields/base.py (97%) rename {llama_toolchain/safety/meta_reference => llama_stack/providers/impls/meta_reference/safety}/shields/code_scanner.py (95%) create mode 100644 llama_stack/providers/impls/meta_reference/safety/shields/contrib/__init__.py rename {llama_toolchain/safety/meta_reference => llama_stack/providers/impls/meta_reference/safety}/shields/contrib/third_party_shield.py (93%) rename {llama_toolchain/safety/meta_reference => llama_stack/providers/impls/meta_reference/safety}/shields/llama_guard.py (99%) rename {llama_toolchain/safety/meta_reference => llama_stack/providers/impls/meta_reference/safety}/shields/prompt_guard.py (99%) rename {llama_toolchain/telemetry/console => llama_stack/providers/impls/meta_reference/telemetry}/__init__.py (100%) rename {llama_toolchain/telemetry/console => llama_stack/providers/impls/meta_reference/telemetry}/config.py (100%) rename {llama_toolchain/telemetry/console => llama_stack/providers/impls/meta_reference/telemetry}/console.py (97%) create mode 100644 llama_stack/providers/impls/sqlite/__init__.py create mode 100644 llama_stack/providers/registry/__init__.py rename llama_toolchain/agentic_system/providers.py => llama_stack/providers/registry/agents.py (69%) rename llama_toolchain/inference/providers.py => llama_stack/providers/registry/inference.py (64%) rename llama_toolchain/memory/providers.py => llama_stack/providers/registry/memory.py (58%) rename llama_toolchain/safety/providers.py => llama_stack/providers/registry/safety.py (69%) rename llama_toolchain/telemetry/providers.py => llama_stack/providers/registry/telemetry.py (58%) create mode 100644 llama_stack/providers/routers/__init__.py create mode 100644 llama_stack/providers/routers/memory/__init__.py create mode 100644 llama_stack/providers/routers/memory/memory.py create mode 100644 llama_stack/providers/utils/__init__.py create mode 100644 llama_stack/providers/utils/inference/__init__.py rename {llama_toolchain => llama_stack/providers/utils}/inference/prepare_messages.py (97%) create mode 100644 llama_stack/providers/utils/memory/__init__.py rename {llama_toolchain/memory/common => llama_stack/providers/utils/memory}/file_utils.py (100%) rename {llama_toolchain/memory/common => llama_stack/providers/utils/memory}/vector_store.py (98%) create mode 100644 llama_stack/providers/utils/telemetry/__init__.py rename {llama_toolchain => llama_stack/providers/utils}/telemetry/tracing.py (99%) delete mode 100644 llama_toolchain/agentic_system/__init__.py delete mode 100644 llama_toolchain/agentic_system/execute_with_custom_tools.py delete mode 100644 llama_toolchain/agentic_system/meta_reference/config.py delete mode 100644 llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml delete mode 100644 llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml delete mode 100644 llama_toolchain/core/configure.py delete mode 100644 llama_toolchain/core/dynamic.py delete mode 100644 llama_toolchain/inference/api/__init__.py delete mode 100644 llama_toolchain/memory/api/__init__.py delete mode 100644 llama_toolchain/post_training/api/__init__.py delete mode 100644 llama_toolchain/reward_scoring/api/__init__.py delete mode 100644 llama_toolchain/safety/api/__init__.py delete mode 100644 llama_toolchain/stack.py delete mode 100644 llama_toolchain/synthetic_data_generation/api/__init__.py delete mode 100644 llama_toolchain/telemetry/api/__init__.py delete mode 100644 llama_toolchain/tools/custom/datatypes.py diff --git a/MANIFEST.in b/MANIFEST.in index 4b76f85fe..e7c63fffd 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include requirements.txt -include llama_toolchain/data/*.yaml -include llama_toolchain/core/*.sh -include llama_toolchain/cli/scripts/*.sh +include llama_stack/distribution/*.sh +include llama_stack/cli/scripts/*.sh +include llama_stack/distribution/example_configs/conda/*.yaml +include llama_stack/distribution/example_configs/docker/*.yaml diff --git a/README.md b/README.md index 118af6e70..0e3efde71 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # llama-stack -[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-toolchain)](https://pypi.org/project/llama-toolchain/) +[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/) [![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/TZAAYNVtrU) This repository contains the specifications and implementations of the APIs which are part of the Llama Stack. @@ -42,7 +42,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c ## Installation -You can install this repository as a [package](https://pypi.org/project/llama-toolchain/) with `pip install llama-toolchain` +You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack` If you want to install from source: diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 8921fc941..f66c1e4cb 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -1,6 +1,6 @@ # Llama CLI Reference -The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-toolchain` package. +The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package. ### Subcommands 1. `download`: `llama` cli tools supports downloading the model from Meta or HuggingFace. @@ -276,16 +276,16 @@ The following command and specifications allows you to get started with building ``` llama stack build ``` -- You will be required to pass in a file path to the build.config file (e.g. `./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_toolchain/configs/distributions/` folder. +- You will be required to pass in a file path to the build.config file (e.g. `./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_stack/distribution/example_configs/` folder. The file will be of the contents ``` -$ cat ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml +$ cat ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml name: 8b-instruct distribution_spec: distribution_type: local - description: Use code from `llama_toolchain` itself to serve all llama stack APIs + description: Use code from `llama_stack` itself to serve all llama stack APIs docker_image: null providers: inference: meta-reference @@ -311,7 +311,7 @@ After this step is complete, a file named `8b-instruct-build.yaml` will be gener To specify a different API provider, we can change the `distribution_spec` in our `-build.yaml` config. For example, the following build spec allows you to build a distribution using TGI as the inference API provider. ``` -$ cat ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml +$ cat ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml name: local-tgi-conda-example distribution_spec: @@ -328,7 +328,7 @@ image_type: conda The following command allows you to build a distribution with TGI as the inference API provider, with the name `tgi`. ``` -llama stack build --config ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml --name tgi +llama stack build --config ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml --name tgi ``` We provide some example build configs to help you get started with building with different API providers. @@ -337,11 +337,11 @@ We provide some example build configs to help you get started with building with To build a docker image, simply change the `image_type` to `docker` in our `-build.yaml` file, and run `llama stack build --config -build.yaml`. ``` -$ cat ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml +$ cat ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml name: local-docker-example distribution_spec: - description: Use code from `llama_toolchain` itself to serve all llama stack APIs + description: Use code from `llama_stack` itself to serve all llama stack APIs docker_image: null providers: inference: meta-reference @@ -354,7 +354,7 @@ image_type: docker The following command allows you to build a Docker image with the name `docker-local` ``` -llama stack build --config ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml --name docker-local +llama stack build --config ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml --name docker-local Dockerfile created successfully in /tmp/tmp.I0ifS2c46A/DockerfileFROM python:3.10-slim WORKDIR /app @@ -480,9 +480,9 @@ This server is running a Llama model locally. Once the server is setup, we can test it with a client to see the example outputs. ``` cd /path/to/llama-stack -conda activate # any environment containing the llama-toolchain pip package will work +conda activate # any environment containing the llama-stack pip package will work -python -m llama_toolchain.inference.client localhost 5000 +python -m llama_stack.apis.inference.client localhost 5000 ``` This will run the chat completion client and query the distribution’s /inference/chat_completion API. @@ -500,7 +500,7 @@ You know what's even more hilarious? People like you who think they can just Goo Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by: ``` -python -m llama_toolchain.safety.client localhost 5000 +python -m llama_stack.safety.client localhost 5000 ``` You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo. diff --git a/docs/getting_started.md b/docs/getting_started.md index a312b8f33..e65581494 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -1,6 +1,6 @@ # Getting Started -The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-toolchain` package. +The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package. This guides allows you to quickly get started with building and running a Llama Stack server in < 5 minutes! @@ -9,7 +9,7 @@ This guides allows you to quickly get started with building and running a Llama **`llama stack build`** ``` -llama stack build --config ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml --name my-local-llama-stack +llama stack build --config ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml --name my-local-llama-stack ... ... Build spec configuration saved at ~/.llama/distributions/conda/my-local-llama-stack-build.yaml @@ -97,16 +97,16 @@ The following command and specifications allows you to get started with building ``` llama stack build ``` -- You will be required to pass in a file path to the build.config file (e.g. `./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_toolchain/configs/distributions/` folder. +- You will be required to pass in a file path to the build.config file (e.g. `./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml`). We provide some example build config files for configuring different types of distributions in the `./llama_stack/distribution/example_configs/` folder. The file will be of the contents ``` -$ cat ./llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml +$ cat ./llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml name: 8b-instruct distribution_spec: distribution_type: local - description: Use code from `llama_toolchain` itself to serve all llama stack APIs + description: Use code from `llama_stack` itself to serve all llama stack APIs docker_image: null providers: inference: meta-reference @@ -132,7 +132,7 @@ After this step is complete, a file named `8b-instruct-build.yaml` will be gener To specify a different API provider, we can change the `distribution_spec` in our `-build.yaml` config. For example, the following build spec allows you to build a distribution using TGI as the inference API provider. ``` -$ cat ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml +$ cat ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml name: local-tgi-conda-example distribution_spec: @@ -149,7 +149,7 @@ image_type: conda The following command allows you to build a distribution with TGI as the inference API provider, with the name `tgi`. ``` -llama stack build --config ./llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml --name tgi +llama stack build --config ./llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml --name tgi ``` We provide some example build configs to help you get started with building with different API providers. @@ -158,11 +158,11 @@ We provide some example build configs to help you get started with building with To build a docker image, simply change the `image_type` to `docker` in our `-build.yaml` file, and run `llama stack build --config -build.yaml`. ``` -$ cat ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml +$ cat ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml name: local-docker-example distribution_spec: - description: Use code from `llama_toolchain` itself to serve all llama stack APIs + description: Use code from `llama_stack` itself to serve all llama stack APIs docker_image: null providers: inference: meta-reference @@ -175,7 +175,7 @@ image_type: docker The following command allows you to build a Docker image with the name `docker-local` ``` -llama stack build --config ./llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml --name docker-local +llama stack build --config ./llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml --name docker-local Dockerfile created successfully in /tmp/tmp.I0ifS2c46A/DockerfileFROM python:3.10-slim WORKDIR /app @@ -294,9 +294,9 @@ This server is running a Llama model locally. Once the server is setup, we can test it with a client to see the example outputs. ``` cd /path/to/llama-stack -conda activate # any environment containing the llama-toolchain pip package will work +conda activate # any environment containing the llama-stack pip package will work -python -m llama_toolchain.inference.client localhost 5000 +python -m llama_stack.apis.inference.client localhost 5000 ``` This will run the chat completion client and query the distribution’s /inference/chat_completion API. @@ -314,7 +314,7 @@ You know what's even more hilarious? People like you who think they can just Goo Similarly you can test safety (if you configured llama-guard and/or prompt-guard shields) by: ``` -python -m llama_toolchain.safety.client localhost 5000 +python -m llama_stack.apis.safety.client localhost 5000 ``` You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/sdk_examples) repo. diff --git a/llama_toolchain/__init__.py b/llama_stack/__init__.py similarity index 100% rename from llama_toolchain/__init__.py rename to llama_stack/__init__.py diff --git a/llama_toolchain/batch_inference/__init__.py b/llama_stack/apis/__init__.py similarity index 100% rename from llama_toolchain/batch_inference/__init__.py rename to llama_stack/apis/__init__.py diff --git a/llama_stack/apis/agents/__init__.py b/llama_stack/apis/agents/__init__.py new file mode 100644 index 000000000..ab203b6cd --- /dev/null +++ b/llama_stack/apis/agents/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .agents import * # noqa: F401 F403 diff --git a/llama_toolchain/agentic_system/api/api.py b/llama_stack/apis/agents/agents.py similarity index 68% rename from llama_toolchain/agentic_system/api/api.py rename to llama_stack/apis/agents/agents.py index 95af3727b..5cc9ce242 100644 --- a/llama_toolchain/agentic_system/api/api.py +++ b/llama_stack/apis/agents/agents.py @@ -14,10 +14,10 @@ from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.common.deployment_types import * # noqa: F403 -from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.safety.api import * # noqa: F403 -from llama_toolchain.memory.api import * # noqa: F403 +from llama_stack.apis.common.deployment_types import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.memory import * # noqa: F403 @json_schema_type @@ -26,7 +26,7 @@ class Attachment(BaseModel): mime_type: str -class AgenticSystemTool(Enum): +class AgentTool(Enum): brave_search = "brave_search" wolfram_alpha = "wolfram_alpha" photogen = "photogen" @@ -50,41 +50,35 @@ class SearchEngineType(Enum): class SearchToolDefinition(ToolDefinitionCommon): # NOTE: brave_search is just a placeholder since model always uses # brave_search as tool call name - type: Literal[AgenticSystemTool.brave_search.value] = ( - AgenticSystemTool.brave_search.value - ) + type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value + api_key: str engine: SearchEngineType = SearchEngineType.brave remote_execution: Optional[RestAPIExecutionConfig] = None @json_schema_type class WolframAlphaToolDefinition(ToolDefinitionCommon): - type: Literal[AgenticSystemTool.wolfram_alpha.value] = ( - AgenticSystemTool.wolfram_alpha.value - ) + type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value + api_key: str remote_execution: Optional[RestAPIExecutionConfig] = None @json_schema_type class PhotogenToolDefinition(ToolDefinitionCommon): - type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value + type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value remote_execution: Optional[RestAPIExecutionConfig] = None @json_schema_type class CodeInterpreterToolDefinition(ToolDefinitionCommon): - type: Literal[AgenticSystemTool.code_interpreter.value] = ( - AgenticSystemTool.code_interpreter.value - ) + type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value enable_inline_code_execution: bool = True remote_execution: Optional[RestAPIExecutionConfig] = None @json_schema_type class FunctionCallToolDefinition(ToolDefinitionCommon): - type: Literal[AgenticSystemTool.function_call.value] = ( - AgenticSystemTool.function_call.value - ) + type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value function_name: str description: str parameters: Dict[str, ToolParamDefinition] @@ -95,30 +89,30 @@ class _MemoryBankConfigCommon(BaseModel): bank_id: str -class AgenticSystemVectorMemoryBankConfig(_MemoryBankConfigCommon): +class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon): type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value -class AgenticSystemKeyValueMemoryBankConfig(_MemoryBankConfigCommon): +class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon): type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value keys: List[str] # what keys to focus on -class AgenticSystemKeywordMemoryBankConfig(_MemoryBankConfigCommon): +class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon): type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value -class AgenticSystemGraphMemoryBankConfig(_MemoryBankConfigCommon): +class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon): type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value entities: List[str] # what entities to focus on MemoryBankConfig = Annotated[ Union[ - AgenticSystemVectorMemoryBankConfig, - AgenticSystemKeyValueMemoryBankConfig, - AgenticSystemKeywordMemoryBankConfig, - AgenticSystemGraphMemoryBankConfig, + AgentVectorMemoryBankConfig, + AgentKeyValueMemoryBankConfig, + AgentKeywordMemoryBankConfig, + AgentGraphMemoryBankConfig, ], Field(discriminator="type"), ] @@ -158,7 +152,7 @@ MemoryQueryGeneratorConfig = Annotated[ class MemoryToolDefinition(ToolDefinitionCommon): - type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value + type: Literal[AgentTool.memory.value] = AgentTool.memory.value memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list) # This config defines how a query is generated using the messages # for memory bank retrieval. @@ -169,7 +163,7 @@ class MemoryToolDefinition(ToolDefinitionCommon): max_chunks: int = 10 -AgenticSystemToolDefinition = Annotated[ +AgentToolDefinition = Annotated[ Union[ SearchToolDefinition, WolframAlphaToolDefinition, @@ -275,7 +269,7 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list) output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list) - tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list) + tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json @@ -292,7 +286,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon): instructions: Optional[str] = None -class AgenticSystemTurnResponseEventType(Enum): +class AgentTurnResponseEventType(Enum): step_start = "step_start" step_complete = "step_complete" step_progress = "step_progress" @@ -302,9 +296,9 @@ class AgenticSystemTurnResponseEventType(Enum): @json_schema_type -class AgenticSystemTurnResponseStepStartPayload(BaseModel): - event_type: Literal[AgenticSystemTurnResponseEventType.step_start.value] = ( - AgenticSystemTurnResponseEventType.step_start.value +class AgentTurnResponseStepStartPayload(BaseModel): + event_type: Literal[AgentTurnResponseEventType.step_start.value] = ( + AgentTurnResponseEventType.step_start.value ) step_type: StepType step_id: str @@ -312,20 +306,20 @@ class AgenticSystemTurnResponseStepStartPayload(BaseModel): @json_schema_type -class AgenticSystemTurnResponseStepCompletePayload(BaseModel): - event_type: Literal[AgenticSystemTurnResponseEventType.step_complete.value] = ( - AgenticSystemTurnResponseEventType.step_complete.value +class AgentTurnResponseStepCompletePayload(BaseModel): + event_type: Literal[AgentTurnResponseEventType.step_complete.value] = ( + AgentTurnResponseEventType.step_complete.value ) step_type: StepType step_details: Step @json_schema_type -class AgenticSystemTurnResponseStepProgressPayload(BaseModel): +class AgentTurnResponseStepProgressPayload(BaseModel): model_config = ConfigDict(protected_namespaces=()) - event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = ( - AgenticSystemTurnResponseEventType.step_progress.value + event_type: Literal[AgentTurnResponseEventType.step_progress.value] = ( + AgentTurnResponseEventType.step_progress.value ) step_type: StepType step_id: str @@ -336,49 +330,49 @@ class AgenticSystemTurnResponseStepProgressPayload(BaseModel): @json_schema_type -class AgenticSystemTurnResponseTurnStartPayload(BaseModel): - event_type: Literal[AgenticSystemTurnResponseEventType.turn_start.value] = ( - AgenticSystemTurnResponseEventType.turn_start.value +class AgentTurnResponseTurnStartPayload(BaseModel): + event_type: Literal[AgentTurnResponseEventType.turn_start.value] = ( + AgentTurnResponseEventType.turn_start.value ) turn_id: str @json_schema_type -class AgenticSystemTurnResponseTurnCompletePayload(BaseModel): - event_type: Literal[AgenticSystemTurnResponseEventType.turn_complete.value] = ( - AgenticSystemTurnResponseEventType.turn_complete.value +class AgentTurnResponseTurnCompletePayload(BaseModel): + event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = ( + AgentTurnResponseEventType.turn_complete.value ) turn: Turn @json_schema_type -class AgenticSystemTurnResponseEvent(BaseModel): +class AgentTurnResponseEvent(BaseModel): """Streamed agent execution response.""" payload: Annotated[ Union[ - AgenticSystemTurnResponseStepStartPayload, - AgenticSystemTurnResponseStepProgressPayload, - AgenticSystemTurnResponseStepCompletePayload, - AgenticSystemTurnResponseTurnStartPayload, - AgenticSystemTurnResponseTurnCompletePayload, + AgentTurnResponseStepStartPayload, + AgentTurnResponseStepProgressPayload, + AgentTurnResponseStepCompletePayload, + AgentTurnResponseTurnStartPayload, + AgentTurnResponseTurnCompletePayload, ], Field(discriminator="event_type"), ] @json_schema_type -class AgenticSystemCreateResponse(BaseModel): +class AgentCreateResponse(BaseModel): agent_id: str @json_schema_type -class AgenticSystemSessionCreateResponse(BaseModel): +class AgentSessionCreateResponse(BaseModel): session_id: str @json_schema_type -class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn): +class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): agent_id: str session_id: str @@ -397,24 +391,24 @@ class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn): @json_schema_type -class AgenticSystemTurnResponseStreamChunk(BaseModel): - event: AgenticSystemTurnResponseEvent +class AgentTurnResponseStreamChunk(BaseModel): + event: AgentTurnResponseEvent @json_schema_type -class AgenticSystemStepResponse(BaseModel): +class AgentStepResponse(BaseModel): step: Step -class AgenticSystem(Protocol): - @webmethod(route="/agentic_system/create") - async def create_agentic_system( +class Agents(Protocol): + @webmethod(route="/agents/create") + async def create_agent( self, agent_config: AgentConfig, - ) -> AgenticSystemCreateResponse: ... + ) -> AgentCreateResponse: ... - @webmethod(route="/agentic_system/turn/create") - async def create_agentic_system_turn( + @webmethod(route="/agents/turn/create") + async def create_agent_turn( self, agent_id: str, session_id: str, @@ -426,42 +420,40 @@ class AgenticSystem(Protocol): ], attachments: Optional[List[Attachment]] = None, stream: Optional[bool] = False, - ) -> AgenticSystemTurnResponseStreamChunk: ... + ) -> AgentTurnResponseStreamChunk: ... - @webmethod(route="/agentic_system/turn/get") - async def get_agentic_system_turn( + @webmethod(route="/agents/turn/get") + async def get_agents_turn( self, agent_id: str, turn_id: str, ) -> Turn: ... - @webmethod(route="/agentic_system/step/get") - async def get_agentic_system_step( + @webmethod(route="/agents/step/get") + async def get_agents_step( self, agent_id: str, turn_id: str, step_id: str - ) -> AgenticSystemStepResponse: ... + ) -> AgentStepResponse: ... - @webmethod(route="/agentic_system/session/create") - async def create_agentic_system_session( + @webmethod(route="/agents/session/create") + async def create_agent_session( self, agent_id: str, session_name: str, - ) -> AgenticSystemSessionCreateResponse: ... + ) -> AgentSessionCreateResponse: ... - @webmethod(route="/agentic_system/session/get") - async def get_agentic_system_session( + @webmethod(route="/agents/session/get") + async def get_agents_session( self, agent_id: str, session_id: str, turn_ids: Optional[List[str]] = None, ) -> Session: ... - @webmethod(route="/agentic_system/session/delete") - async def delete_agentic_system_session( - self, agent_id: str, session_id: str - ) -> None: ... + @webmethod(route="/agents/session/delete") + async def delete_agents_session(self, agent_id: str, session_id: str) -> None: ... - @webmethod(route="/agentic_system/delete") - async def delete_agentic_system( + @webmethod(route="/agents/delete") + async def delete_agents( self, agent_id: str, ) -> None: ... diff --git a/llama_toolchain/agentic_system/client.py b/llama_stack/apis/agents/client.py similarity index 79% rename from llama_toolchain/agentic_system/client.py rename to llama_stack/apis/agents/client.py index 52cf0dee2..c5cba3541 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_stack/apis/agents/client.py @@ -6,56 +6,58 @@ import asyncio import json +import os from typing import AsyncGenerator import fire - import httpx +from dotenv import load_dotenv from pydantic import BaseModel from termcolor import cprint from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.core.datatypes import RemoteProviderConfig +from llama_stack.distribution.datatypes import RemoteProviderConfig -from .api import * # noqa: F403 +from .agents import * # noqa: F403 from .event_logger import EventLogger +load_dotenv() + + async def get_client_impl(config: RemoteProviderConfig, _deps): - return AgenticSystemClient(config.url) + return AgentsClient(config.url) def encodable_dict(d: BaseModel): return json.loads(d.json()) -class AgenticSystemClient(AgenticSystem): +class AgentsClient(Agents): def __init__(self, base_url: str): self.base_url = base_url - async def create_agentic_system( - self, agent_config: AgentConfig - ) -> AgenticSystemCreateResponse: + async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse: async with httpx.AsyncClient() as client: response = await client.post( - f"{self.base_url}/agentic_system/create", + f"{self.base_url}/agents/create", json={ "agent_config": encodable_dict(agent_config), }, headers={"Content-Type": "application/json"}, ) response.raise_for_status() - return AgenticSystemCreateResponse(**response.json()) + return AgentCreateResponse(**response.json()) - async def create_agentic_system_session( + async def create_agent_session( self, agent_id: str, session_name: str, - ) -> AgenticSystemSessionCreateResponse: + ) -> AgentSessionCreateResponse: async with httpx.AsyncClient() as client: response = await client.post( - f"{self.base_url}/agentic_system/session/create", + f"{self.base_url}/agents/session/create", json={ "agent_id": agent_id, "session_name": session_name, @@ -63,16 +65,16 @@ class AgenticSystemClient(AgenticSystem): headers={"Content-Type": "application/json"}, ) response.raise_for_status() - return AgenticSystemSessionCreateResponse(**response.json()) + return AgentSessionCreateResponse(**response.json()) - async def create_agentic_system_turn( + async def create_agent_turn( self, - request: AgenticSystemTurnCreateRequest, + request: AgentTurnCreateRequest, ) -> AsyncGenerator: async with httpx.AsyncClient() as client: async with client.stream( "POST", - f"{self.base_url}/agentic_system/turn/create", + f"{self.base_url}/agents/turn/create", json=encodable_dict(request), headers={"Content-Type": "application/json"}, timeout=20, @@ -86,7 +88,7 @@ class AgenticSystemClient(AgenticSystem): cprint(data, "red") continue - yield AgenticSystemTurnResponseStreamChunk(**jdata) + yield AgentTurnResponseStreamChunk(**jdata) except Exception as e: print(data) print(f"Error with parsing or validation: {e}") @@ -102,16 +104,16 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None): tool_prompt_format=ToolPromptFormat.function_tag, ) - create_response = await api.create_agentic_system(agent_config) - session_response = await api.create_agentic_system_session( + create_response = await api.create_agent(agent_config) + session_response = await api.create_agent_session( agent_id=create_response.agent_id, session_name="test_session", ) for content in user_prompts: cprint(f"User> {content}", color="white", attrs=["bold"]) - iterator = api.create_agentic_system_turn( - AgenticSystemTurnCreateRequest( + iterator = api.create_agent_turn( + AgentTurnCreateRequest( agent_id=create_response.agent_id, session_id=session_response.session_id, messages=[ @@ -128,11 +130,14 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None): async def run_main(host: str, port: int): - api = AgenticSystemClient(f"http://{host}:{port}") + api = AgentsClient(f"http://{host}:{port}") tool_definitions = [ - SearchToolDefinition(engine=SearchEngineType.bing), - WolframAlphaToolDefinition(), + SearchToolDefinition( + engine=SearchEngineType.brave, + api_key=os.getenv("BRAVE_SEARCH_API_KEY"), + ), + WolframAlphaToolDefinition(api_key=os.getenv("WOLFRAM_ALPHA_API_KEY")), CodeInterpreterToolDefinition(), ] tool_definitions += [ @@ -165,7 +170,7 @@ async def run_main(host: str, port: int): async def run_rag(host: str, port: int): - api = AgenticSystemClient(f"http://{host}:{port}") + api = AgentsClient(f"http://{host}:{port}") urls = [ "memory_optimizations.rst", @@ -186,7 +191,7 @@ async def run_rag(host: str, port: int): ] # Alternatively, you can pre-populate the memory bank with documents for example, - # using `llama_toolchain.memory.client`. Then you can grab the bank_id + # using `llama_stack.memory.client`. Then you can grab the bank_id # from the output of that run. tool_definitions = [ MemoryToolDefinition( diff --git a/llama_toolchain/agentic_system/event_logger.py b/llama_stack/apis/agents/event_logger.py similarity index 97% rename from llama_toolchain/agentic_system/event_logger.py rename to llama_stack/apis/agents/event_logger.py index 3d15ee239..9cbd1fbd2 100644 --- a/llama_toolchain/agentic_system/event_logger.py +++ b/llama_stack/apis/agents/event_logger.py @@ -9,12 +9,9 @@ from typing import Optional from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.tool_utils import ToolUtils -from termcolor import cprint +from llama_stack.apis.agents import AgentTurnResponseEventType, StepType -from llama_toolchain.agentic_system.api import ( - AgenticSystemTurnResponseEventType, - StepType, -) +from termcolor import cprint class LogEvent: @@ -40,7 +37,7 @@ class LogEvent: cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush) -EventType = AgenticSystemTurnResponseEventType +EventType = AgentTurnResponseEventType class EventLogger: diff --git a/llama_stack/apis/batch_inference/__init__.py b/llama_stack/apis/batch_inference/__init__.py new file mode 100644 index 000000000..3249475ee --- /dev/null +++ b/llama_stack/apis/batch_inference/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .batch_inference import * # noqa: F401 F403 diff --git a/llama_toolchain/batch_inference/api/api.py b/llama_stack/apis/batch_inference/batch_inference.py similarity index 97% rename from llama_toolchain/batch_inference/api/api.py rename to llama_stack/apis/batch_inference/batch_inference.py index 3d67120dd..0c3132812 100644 --- a/llama_toolchain/batch_inference/api/api.py +++ b/llama_stack/apis/batch_inference/batch_inference.py @@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.inference.api import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 @json_schema_type diff --git a/llama_toolchain/cli/__init__.py b/llama_stack/apis/common/__init__.py similarity index 100% rename from llama_toolchain/cli/__init__.py rename to llama_stack/apis/common/__init__.py diff --git a/llama_toolchain/common/deployment_types.py b/llama_stack/apis/common/deployment_types.py similarity index 100% rename from llama_toolchain/common/deployment_types.py rename to llama_stack/apis/common/deployment_types.py diff --git a/llama_toolchain/common/training_types.py b/llama_stack/apis/common/training_types.py similarity index 100% rename from llama_toolchain/common/training_types.py rename to llama_stack/apis/common/training_types.py diff --git a/llama_stack/apis/dataset/__init__.py b/llama_stack/apis/dataset/__init__.py new file mode 100644 index 000000000..33557a0ab --- /dev/null +++ b/llama_stack/apis/dataset/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .dataset import * # noqa: F401 F403 diff --git a/llama_toolchain/dataset/api/api.py b/llama_stack/apis/dataset/dataset.py similarity index 100% rename from llama_toolchain/dataset/api/api.py rename to llama_stack/apis/dataset/dataset.py diff --git a/llama_stack/apis/evals/__init__.py b/llama_stack/apis/evals/__init__.py new file mode 100644 index 000000000..d21b97d0a --- /dev/null +++ b/llama_stack/apis/evals/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .evals import * # noqa: F401 F403 diff --git a/llama_toolchain/evaluations/api/api.py b/llama_stack/apis/evals/evals.py similarity index 95% rename from llama_toolchain/evaluations/api/api.py rename to llama_stack/apis/evals/evals.py index 898dc2822..0be2243ab 100644 --- a/llama_toolchain/evaluations/api/api.py +++ b/llama_stack/apis/evals/evals.py @@ -12,8 +12,8 @@ from llama_models.schema_utils import webmethod from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.dataset.api import * # noqa: F403 -from llama_toolchain.common.training_types import * # noqa: F403 +from llama_stack.apis.dataset import * # noqa: F403 +from llama_stack.apis.common.training_types import * # noqa: F403 class TextGenerationMetric(Enum): diff --git a/llama_stack/apis/inference/__init__.py b/llama_stack/apis/inference/__init__.py new file mode 100644 index 000000000..f9f77f769 --- /dev/null +++ b/llama_stack/apis/inference/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .inference import * # noqa: F401 F403 diff --git a/llama_toolchain/inference/client.py b/llama_stack/apis/inference/client.py similarity index 97% rename from llama_toolchain/inference/client.py rename to llama_stack/apis/inference/client.py index c57433a8f..f5321c628 100644 --- a/llama_toolchain/inference/client.py +++ b/llama_stack/apis/inference/client.py @@ -11,11 +11,13 @@ from typing import Any, AsyncGenerator import fire import httpx -from llama_toolchain.core.datatypes import RemoteProviderConfig +from llama_stack.distribution.datatypes import RemoteProviderConfig from pydantic import BaseModel from termcolor import cprint -from .api import ( +from .event_logger import EventLogger + +from .inference import ( ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseStreamChunk, @@ -23,7 +25,6 @@ from .api import ( Inference, UserMessage, ) -from .event_logger import EventLogger async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference: diff --git a/llama_toolchain/inference/event_logger.py b/llama_stack/apis/inference/event_logger.py similarity index 97% rename from llama_toolchain/inference/event_logger.py rename to llama_stack/apis/inference/event_logger.py index 248ceae27..c64ffb6bd 100644 --- a/llama_toolchain/inference/event_logger.py +++ b/llama_stack/apis/inference/event_logger.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_toolchain.inference.api import ( +from llama_stack.apis.inference import ( ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, ) diff --git a/llama_toolchain/inference/api/api.py b/llama_stack/apis/inference/inference.py similarity index 100% rename from llama_toolchain/inference/api/api.py rename to llama_stack/apis/inference/inference.py diff --git a/llama_stack/apis/memory/__init__.py b/llama_stack/apis/memory/__init__.py new file mode 100644 index 000000000..260862228 --- /dev/null +++ b/llama_stack/apis/memory/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .memory import * # noqa: F401 F403 diff --git a/llama_toolchain/memory/client.py b/llama_stack/apis/memory/client.py similarity index 97% rename from llama_toolchain/memory/client.py rename to llama_stack/apis/memory/client.py index 5f74219da..d2845326b 100644 --- a/llama_toolchain/memory/client.py +++ b/llama_stack/apis/memory/client.py @@ -6,17 +6,18 @@ import asyncio import json +import os from pathlib import Path from typing import Any, Dict, List, Optional import fire import httpx + +from llama_stack.distribution.datatypes import RemoteProviderConfig from termcolor import cprint -from llama_toolchain.core.datatypes import RemoteProviderConfig - -from .api import * # noqa: F403 +from .memory import * # noqa: F403 from .common.file_utils import data_url_from_file diff --git a/llama_toolchain/memory/api/api.py b/llama_stack/apis/memory/memory.py similarity index 100% rename from llama_toolchain/memory/api/api.py rename to llama_stack/apis/memory/memory.py diff --git a/llama_stack/apis/models/__init__.py b/llama_stack/apis/models/__init__.py new file mode 100644 index 000000000..410d8d1f9 --- /dev/null +++ b/llama_stack/apis/models/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .models import * # noqa: F401 F403 diff --git a/llama_toolchain/models/api/api.py b/llama_stack/apis/models/models.py similarity index 100% rename from llama_toolchain/models/api/api.py rename to llama_stack/apis/models/models.py diff --git a/llama_stack/apis/post_training/__init__.py b/llama_stack/apis/post_training/__init__.py new file mode 100644 index 000000000..7129c4abd --- /dev/null +++ b/llama_stack/apis/post_training/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .post_training import * # noqa: F401 F403 diff --git a/llama_toolchain/post_training/api/api.py b/llama_stack/apis/post_training/post_training.py similarity index 97% rename from llama_toolchain/post_training/api/api.py rename to llama_stack/apis/post_training/post_training.py index 378515f83..d943f48b2 100644 --- a/llama_toolchain/post_training/api/api.py +++ b/llama_stack/apis/post_training/post_training.py @@ -14,8 +14,8 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.dataset.api import * # noqa: F403 -from llama_toolchain.common.training_types import * # noqa: F403 +from llama_stack.apis.dataset import * # noqa: F403 +from llama_stack.apis.common.training_types import * # noqa: F403 class OptimizerType(Enum): diff --git a/llama_stack/apis/reward_scoring/__init__.py b/llama_stack/apis/reward_scoring/__init__.py new file mode 100644 index 000000000..7ea62c241 --- /dev/null +++ b/llama_stack/apis/reward_scoring/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .reward_scoring import * # noqa: F401 F403 diff --git a/llama_toolchain/reward_scoring/api/api.py b/llama_stack/apis/reward_scoring/reward_scoring.py similarity index 100% rename from llama_toolchain/reward_scoring/api/api.py rename to llama_stack/apis/reward_scoring/reward_scoring.py diff --git a/llama_stack/apis/safety/__init__.py b/llama_stack/apis/safety/__init__.py new file mode 100644 index 000000000..dc3fe90b4 --- /dev/null +++ b/llama_stack/apis/safety/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .safety import * # noqa: F401 F403 diff --git a/llama_toolchain/safety/client.py b/llama_stack/apis/safety/client.py similarity index 95% rename from llama_toolchain/safety/client.py rename to llama_stack/apis/safety/client.py index 26a9813b3..b7472686a 100644 --- a/llama_toolchain/safety/client.py +++ b/llama_stack/apis/safety/client.py @@ -14,11 +14,11 @@ import httpx from llama_models.llama3.api.datatypes import UserMessage -from llama_toolchain.core.datatypes import RemoteProviderConfig +from llama_stack.distribution.datatypes import RemoteProviderConfig from pydantic import BaseModel from termcolor import cprint -from .api import * # noqa: F403 +from .safety import * # noqa: F403 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety: diff --git a/llama_toolchain/safety/api/api.py b/llama_stack/apis/safety/safety.py similarity index 96% rename from llama_toolchain/safety/api/api.py rename to llama_stack/apis/safety/safety.py index 631cfa992..2733dde73 100644 --- a/llama_toolchain/safety/api/api.py +++ b/llama_stack/apis/safety/safety.py @@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, validator from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.common.deployment_types import RestAPIExecutionConfig +from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig @json_schema_type diff --git a/llama_stack/apis/stack.py b/llama_stack/apis/stack.py new file mode 100644 index 000000000..f6c66d23b --- /dev/null +++ b/llama_stack/apis/stack.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.agents import * # noqa: F403 +from llama_stack.apis.dataset import * # noqa: F403 +from llama_stack.apis.evals import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.batch_inference import * # noqa: F403 +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.telemetry import * # noqa: F403 +from llama_stack.apis.post_training import * # noqa: F403 +from llama_stack.apis.reward_scoring import * # noqa: F403 +from llama_stack.apis.synthetic_data_generation import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 + + +class LlamaStack( + Inference, + BatchInference, + Agents, + RewardScoring, + Safety, + SyntheticDataGeneration, + Datasets, + Telemetry, + PostTraining, + Memory, + Evaluations, +): + pass diff --git a/llama_stack/apis/synthetic_data_generation/__init__.py b/llama_stack/apis/synthetic_data_generation/__init__.py new file mode 100644 index 000000000..cfdec76ce --- /dev/null +++ b/llama_stack/apis/synthetic_data_generation/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .synthetic_data_generation import * # noqa: F401 F403 diff --git a/llama_toolchain/synthetic_data_generation/api/api.py b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py similarity index 96% rename from llama_toolchain/synthetic_data_generation/api/api.py rename to llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py index 9a6c487af..60c756128 100644 --- a/llama_toolchain/synthetic_data_generation/api/api.py +++ b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py @@ -13,7 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.reward_scoring.api import * # noqa: F403 +from llama_stack.apis.reward_scoring import * # noqa: F403 class FilteringFunction(Enum): diff --git a/llama_stack/apis/telemetry/__init__.py b/llama_stack/apis/telemetry/__init__.py new file mode 100644 index 000000000..6a111dc9e --- /dev/null +++ b/llama_stack/apis/telemetry/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .telemetry import * # noqa: F401 F403 diff --git a/llama_toolchain/telemetry/api/api.py b/llama_stack/apis/telemetry/telemetry.py similarity index 100% rename from llama_toolchain/telemetry/api/api.py rename to llama_stack/apis/telemetry/telemetry.py diff --git a/llama_toolchain/cli/scripts/__init__.py b/llama_stack/cli/__init__.py similarity index 100% rename from llama_toolchain/cli/scripts/__init__.py rename to llama_stack/cli/__init__.py diff --git a/llama_toolchain/cli/download.py b/llama_stack/cli/download.py similarity index 97% rename from llama_toolchain/cli/download.py rename to llama_stack/cli/download.py index 1bfa89fc6..618036665 100644 --- a/llama_toolchain/cli/download.py +++ b/llama_stack/cli/download.py @@ -20,7 +20,7 @@ from pydantic import BaseModel from termcolor import cprint -from llama_toolchain.cli.subcommand import Subcommand +from llama_stack.cli.subcommand import Subcommand class Download(Subcommand): @@ -92,7 +92,7 @@ def _hf_download( from huggingface_hub import snapshot_download from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError - from llama_toolchain.common.model_utils import model_local_dir + from llama_stack.distribution.utils.model_utils import model_local_dir repo_id = model.huggingface_repo if repo_id is None: @@ -106,7 +106,7 @@ def _hf_download( local_dir=output_dir, ignore_patterns=ignore_patterns, token=hf_token, - library_name="llama-toolchain", + library_name="llama-stack", ) except GatedRepoError: parser.error( @@ -126,7 +126,7 @@ def _hf_download( def _meta_download(model: "Model", meta_url: str): from llama_models.sku_list import llama_meta_net_info - from llama_toolchain.common.model_utils import model_local_dir + from llama_stack.distribution.utils.model_utils import model_local_dir output_dir = Path(model_local_dir(model.descriptor())) os.makedirs(output_dir, exist_ok=True) @@ -188,7 +188,7 @@ class Manifest(BaseModel): def _download_from_manifest(manifest_file: str): - from llama_toolchain.common.model_utils import model_local_dir + from llama_stack.distribution.utils.model_utils import model_local_dir with open(manifest_file, "r") as f: d = json.load(f) diff --git a/llama_toolchain/cli/llama.py b/llama_stack/cli/llama.py similarity index 80% rename from llama_toolchain/cli/llama.py rename to llama_stack/cli/llama.py index 9a5530c0c..8ca82db81 100644 --- a/llama_toolchain/cli/llama.py +++ b/llama_stack/cli/llama.py @@ -31,16 +31,6 @@ class LlamaCLIParser: ModelParser.create(subparsers) StackParser.create(subparsers) - # Import sub-commands from agentic_system if they exist - try: - from llama_agentic_system.cli.subcommand_modules import SUBCOMMAND_MODULES - - for module in SUBCOMMAND_MODULES: - module.create(subparsers) - - except ImportError: - pass - def parse_args(self) -> argparse.Namespace: return self.parser.parse_args() diff --git a/llama_toolchain/cli/model/__init__.py b/llama_stack/cli/model/__init__.py similarity index 100% rename from llama_toolchain/cli/model/__init__.py rename to llama_stack/cli/model/__init__.py diff --git a/llama_toolchain/cli/model/describe.py b/llama_stack/cli/model/describe.py similarity index 93% rename from llama_toolchain/cli/model/describe.py rename to llama_stack/cli/model/describe.py index 683995f7b..b100f7544 100644 --- a/llama_toolchain/cli/model/describe.py +++ b/llama_stack/cli/model/describe.py @@ -9,12 +9,12 @@ import json from llama_models.sku_list import resolve_model -from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.cli.table import print_table -from llama_toolchain.common.serialize import EnumEncoder - from termcolor import colored +from llama_stack.cli.subcommand import Subcommand +from llama_stack.cli.table import print_table +from llama_stack.distribution.utils.serialize import EnumEncoder + class ModelDescribe(Subcommand): """Show details about a model""" diff --git a/llama_toolchain/cli/model/download.py b/llama_stack/cli/model/download.py similarity index 83% rename from llama_toolchain/cli/model/download.py rename to llama_stack/cli/model/download.py index ac3c791b4..a3b8f7796 100644 --- a/llama_toolchain/cli/model/download.py +++ b/llama_stack/cli/model/download.py @@ -6,7 +6,7 @@ import argparse -from llama_toolchain.cli.subcommand import Subcommand +from llama_stack.cli.subcommand import Subcommand class ModelDownload(Subcommand): @@ -19,6 +19,6 @@ class ModelDownload(Subcommand): formatter_class=argparse.RawTextHelpFormatter, ) - from llama_toolchain.cli.download import setup_download_parser + from llama_stack.cli.download import setup_download_parser setup_download_parser(self.parser) diff --git a/llama_toolchain/cli/model/list.py b/llama_stack/cli/model/list.py similarity index 94% rename from llama_toolchain/cli/model/list.py rename to llama_stack/cli/model/list.py index f989260ab..977590d7a 100644 --- a/llama_toolchain/cli/model/list.py +++ b/llama_stack/cli/model/list.py @@ -8,8 +8,8 @@ import argparse from llama_models.sku_list import all_registered_models -from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.cli.table import print_table +from llama_stack.cli.subcommand import Subcommand +from llama_stack.cli.table import print_table class ModelList(Subcommand): diff --git a/llama_toolchain/cli/model/model.py b/llama_stack/cli/model/model.py similarity index 73% rename from llama_toolchain/cli/model/model.py rename to llama_stack/cli/model/model.py index 9a14450ad..c222c1d63 100644 --- a/llama_toolchain/cli/model/model.py +++ b/llama_stack/cli/model/model.py @@ -6,12 +6,12 @@ import argparse -from llama_toolchain.cli.model.describe import ModelDescribe -from llama_toolchain.cli.model.download import ModelDownload -from llama_toolchain.cli.model.list import ModelList -from llama_toolchain.cli.model.template import ModelTemplate +from llama_stack.cli.model.describe import ModelDescribe +from llama_stack.cli.model.download import ModelDownload +from llama_stack.cli.model.list import ModelList +from llama_stack.cli.model.template import ModelTemplate -from llama_toolchain.cli.subcommand import Subcommand +from llama_stack.cli.subcommand import Subcommand class ModelParser(Subcommand): diff --git a/llama_toolchain/cli/model/template.py b/llama_stack/cli/model/template.py similarity index 97% rename from llama_toolchain/cli/model/template.py rename to llama_stack/cli/model/template.py index 2776d9703..d828660bb 100644 --- a/llama_toolchain/cli/model/template.py +++ b/llama_stack/cli/model/template.py @@ -9,7 +9,7 @@ import textwrap from termcolor import colored -from llama_toolchain.cli.subcommand import Subcommand +from llama_stack.cli.subcommand import Subcommand class ModelTemplate(Subcommand): @@ -75,7 +75,7 @@ class ModelTemplate(Subcommand): render_jinja_template, ) - from llama_toolchain.cli.table import print_table + from llama_stack.cli.table import print_table if args.name: tool_prompt_format = self._prompt_type(args.format) diff --git a/llama_toolchain/common/__init__.py b/llama_stack/cli/scripts/__init__.py similarity index 100% rename from llama_toolchain/common/__init__.py rename to llama_stack/cli/scripts/__init__.py diff --git a/llama_toolchain/cli/scripts/install-wheel-from-presigned.sh b/llama_stack/cli/scripts/install-wheel-from-presigned.sh similarity index 100% rename from llama_toolchain/cli/scripts/install-wheel-from-presigned.sh rename to llama_stack/cli/scripts/install-wheel-from-presigned.sh diff --git a/llama_toolchain/cli/scripts/run.py b/llama_stack/cli/scripts/run.py similarity index 100% rename from llama_toolchain/cli/scripts/run.py rename to llama_stack/cli/scripts/run.py diff --git a/llama_toolchain/cli/stack/__init__.py b/llama_stack/cli/stack/__init__.py similarity index 100% rename from llama_toolchain/cli/stack/__init__.py rename to llama_stack/cli/stack/__init__.py diff --git a/llama_toolchain/cli/stack/build.py b/llama_stack/cli/stack/build.py similarity index 76% rename from llama_toolchain/cli/stack/build.py rename to llama_stack/cli/stack/build.py index 36cd480fc..f6f79b621 100644 --- a/llama_toolchain/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -6,8 +6,8 @@ import argparse -from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.core.datatypes import * # noqa: F403 +from llama_stack.cli.subcommand import Subcommand +from llama_stack.distribution.datatypes import * # noqa: F403 from pathlib import Path import yaml @@ -29,7 +29,7 @@ class StackBuild(Subcommand): self.parser.add_argument( "config", type=str, - help="Path to a config file to use for the build. You may find example configs in llama_toolchain/configs/distributions", + help="Path to a config file to use for the build. You may find example configs in llama_stack/distribution/example_configs", ) self.parser.add_argument( @@ -44,17 +44,17 @@ class StackBuild(Subcommand): import json import os - from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR - from llama_toolchain.common.serialize import EnumEncoder - from llama_toolchain.core.package import ApiInput, build_package, ImageType + from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR + from llama_stack.distribution.utils.serialize import EnumEncoder + from llama_stack.distribution.build import ApiInput, build_image, ImageType from termcolor import cprint # save build.yaml spec for building same distribution again if build_config.image_type == ImageType.docker.value: # docker needs build file to be in the llama-stack repo dir to be able to copy over to the image - llama_toolchain_path = Path(os.path.relpath(__file__)).parent.parent.parent + llama_stack_path = Path(os.path.relpath(__file__)).parent.parent.parent build_dir = ( - llama_toolchain_path / "configs/distributions" / build_config.image_type + llama_stack_path / "configs/distributions" / build_config.image_type ) else: build_dir = DISTRIBS_BASE_DIR / build_config.image_type @@ -66,7 +66,7 @@ class StackBuild(Subcommand): to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder)) f.write(yaml.dump(to_write, sort_keys=False)) - build_package(build_config, build_file_path) + build_image(build_config, build_file_path) cprint( f"Build spec configuration saved at {str(build_file_path)}", @@ -74,12 +74,12 @@ class StackBuild(Subcommand): ) def _run_stack_build_command(self, args: argparse.Namespace) -> None: - from llama_toolchain.common.prompt_for_config import prompt_for_config - from llama_toolchain.core.dynamic import instantiate_class_type + from llama_stack.distribution.utils.prompt_for_config import prompt_for_config + from llama_stack.distribution.utils.dynamic import instantiate_class_type if not args.config: self.parser.error( - "No config file specified. Please use `llama stack build /path/to/*-build.yaml`. Example config files can be found in llama_toolchain/configs/distributions" + "No config file specified. Please use `llama stack build /path/to/*-build.yaml`. Example config files can be found in llama_stack/distribution/example_configs" ) return diff --git a/llama_toolchain/cli/stack/configure.py b/llama_stack/cli/stack/configure.py similarity index 82% rename from llama_toolchain/cli/stack/configure.py rename to llama_stack/cli/stack/configure.py index 4a73f1af4..d739aad50 100644 --- a/llama_toolchain/cli/stack/configure.py +++ b/llama_stack/cli/stack/configure.py @@ -13,11 +13,11 @@ import pkg_resources import yaml from termcolor import cprint -from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR +from llama_stack.cli.subcommand import Subcommand +from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR -from llama_toolchain.common.exec import run_with_pty -from llama_toolchain.core.datatypes import * # noqa: F403 +from llama_stack.distribution.utils.exec import run_with_pty +from llama_stack.distribution.datatypes import * # noqa: F403 import os @@ -49,7 +49,7 @@ class StackConfigure(Subcommand): ) def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None: - from llama_toolchain.core.package import ImageType + from llama_stack.distribution.build import ImageType docker_image = None build_config_file = Path(args.config) @@ -66,7 +66,7 @@ class StackConfigure(Subcommand): os.makedirs(builds_dir, exist_ok=True) script = pkg_resources.resource_filename( - "llama_toolchain", "core/configure_container.sh" + "llama_stack", "distribution/configure_container.sh" ) script_args = [script, docker_image, str(builds_dir)] @@ -95,8 +95,8 @@ class StackConfigure(Subcommand): build_config: BuildConfig, output_dir: Optional[str] = None, ): - from llama_toolchain.common.serialize import EnumEncoder - from llama_toolchain.core.configure import configure_api_providers + from llama_stack.distribution.configure import configure_api_providers + from llama_stack.distribution.utils.serialize import EnumEncoder builds_dir = BUILDS_BASE_DIR / build_config.image_type if output_dir: @@ -105,16 +105,9 @@ class StackConfigure(Subcommand): image_name = build_config.name.replace("::", "-") run_config_file = builds_dir / f"{image_name}-run.yaml" - api2providers = build_config.distribution_spec.providers - - stub_config = { - api_str: {"provider_id": provider} - for api_str, provider in api2providers.items() - } - if run_config_file.exists(): cprint( - f"Configuration already exists for {build_config.name}. Will overwrite...", + f"Configuration already exists at `{str(run_config_file)}`. Will overwrite...", "yellow", attrs=["bold"], ) @@ -123,10 +116,12 @@ class StackConfigure(Subcommand): config = StackRunConfig( built_at=datetime.now(), image_name=image_name, - providers=stub_config, + apis_to_serve=[], + provider_map={}, ) - config.providers = configure_api_providers(config.providers) + config = configure_api_providers(config, build_config.distribution_spec) + config.docker_image = ( image_name if build_config.image_type == "docker" else None ) diff --git a/llama_toolchain/cli/stack/list_apis.py b/llama_stack/cli/stack/list_apis.py similarity index 87% rename from llama_toolchain/cli/stack/list_apis.py rename to llama_stack/cli/stack/list_apis.py index f13ecefe9..cac803f92 100644 --- a/llama_toolchain/cli/stack/list_apis.py +++ b/llama_stack/cli/stack/list_apis.py @@ -6,7 +6,7 @@ import argparse -from llama_toolchain.cli.subcommand import Subcommand +from llama_stack.cli.subcommand import Subcommand class StackListApis(Subcommand): @@ -25,8 +25,8 @@ class StackListApis(Subcommand): pass def _run_apis_list_cmd(self, args: argparse.Namespace) -> None: - from llama_toolchain.cli.table import print_table - from llama_toolchain.core.distribution import stack_apis + from llama_stack.cli.table import print_table + from llama_stack.distribution.distribution import stack_apis # eventually, this should query a registry at llama.meta.com/llamastack/distributions headers = [ diff --git a/llama_toolchain/cli/stack/list_providers.py b/llama_stack/cli/stack/list_providers.py similarity index 87% rename from llama_toolchain/cli/stack/list_providers.py rename to llama_stack/cli/stack/list_providers.py index a5640677d..33cfe6939 100644 --- a/llama_toolchain/cli/stack/list_providers.py +++ b/llama_stack/cli/stack/list_providers.py @@ -6,7 +6,7 @@ import argparse -from llama_toolchain.cli.subcommand import Subcommand +from llama_stack.cli.subcommand import Subcommand class StackListProviders(Subcommand): @@ -22,7 +22,7 @@ class StackListProviders(Subcommand): self.parser.set_defaults(func=self._run_providers_list_cmd) def _add_arguments(self): - from llama_toolchain.core.distribution import stack_apis + from llama_stack.distribution.distribution import stack_apis api_values = [a.value for a in stack_apis()] self.parser.add_argument( @@ -33,8 +33,8 @@ class StackListProviders(Subcommand): ) def _run_providers_list_cmd(self, args: argparse.Namespace) -> None: - from llama_toolchain.cli.table import print_table - from llama_toolchain.core.distribution import Api, api_providers + from llama_stack.cli.table import print_table + from llama_stack.distribution.distribution import Api, api_providers all_providers = api_providers() providers_for_api = all_providers[Api(args.api)] diff --git a/llama_toolchain/cli/stack/run.py b/llama_stack/cli/stack/run.py similarity index 87% rename from llama_toolchain/cli/stack/run.py rename to llama_stack/cli/stack/run.py index b5900eaba..acdbcf3bc 100644 --- a/llama_toolchain/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -11,8 +11,8 @@ from pathlib import Path import pkg_resources import yaml -from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.core.datatypes import * # noqa: F403 +from llama_stack.cli.subcommand import Subcommand +from llama_stack.distribution.datatypes import * # noqa: F403 class StackRun(Subcommand): @@ -47,7 +47,7 @@ class StackRun(Subcommand): ) def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: - from llama_toolchain.common.exec import run_with_pty + from llama_stack.distribution.utils.exec import run_with_pty if not args.config: self.parser.error("Must specify a config file to run") @@ -67,14 +67,14 @@ class StackRun(Subcommand): if config.docker_image: script = pkg_resources.resource_filename( - "llama_toolchain", - "core/start_container.sh", + "llama_stack", + "distribution/start_container.sh", ) run_args = [script, config.docker_image] else: script = pkg_resources.resource_filename( - "llama_toolchain", - "core/start_conda_env.sh", + "llama_stack", + "distribution/start_conda_env.sh", ) run_args = [ script, diff --git a/llama_toolchain/cli/stack/stack.py b/llama_stack/cli/stack/stack.py similarity index 94% rename from llama_toolchain/cli/stack/stack.py rename to llama_stack/cli/stack/stack.py index 0e4abb5a2..c359d27ec 100644 --- a/llama_toolchain/cli/stack/stack.py +++ b/llama_stack/cli/stack/stack.py @@ -6,7 +6,7 @@ import argparse -from llama_toolchain.cli.subcommand import Subcommand +from llama_stack.cli.subcommand import Subcommand from .build import StackBuild from .configure import StackConfigure diff --git a/llama_toolchain/cli/subcommand.py b/llama_stack/cli/subcommand.py similarity index 100% rename from llama_toolchain/cli/subcommand.py rename to llama_stack/cli/subcommand.py diff --git a/llama_toolchain/cli/table.py b/llama_stack/cli/table.py similarity index 100% rename from llama_toolchain/cli/table.py rename to llama_stack/cli/table.py diff --git a/llama_toolchain/core/__init__.py b/llama_stack/distribution/__init__.py similarity index 100% rename from llama_toolchain/core/__init__.py rename to llama_stack/distribution/__init__.py diff --git a/llama_toolchain/core/package.py b/llama_stack/distribution/build.py similarity index 58% rename from llama_toolchain/core/package.py rename to llama_stack/distribution/build.py index 7987384e2..95cea6caa 100644 --- a/llama_toolchain/core/package.py +++ b/llama_stack/distribution/build.py @@ -4,26 +4,20 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json -import os -from datetime import datetime from enum import Enum from typing import List, Optional import pkg_resources -import yaml - -from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR -from llama_toolchain.common.exec import run_with_pty -from llama_toolchain.common.serialize import EnumEncoder from pydantic import BaseModel from termcolor import cprint -from llama_toolchain.core.datatypes import * # noqa: F403 +from llama_stack.distribution.utils.exec import run_with_pty + +from llama_stack.distribution.datatypes import * # noqa: F403 from pathlib import Path -from llama_toolchain.core.distribution import api_providers, SERVER_DEPENDENCIES +from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES class ImageType(Enum): @@ -41,7 +35,7 @@ class ApiInput(BaseModel): provider: str -def build_package(build_config: BuildConfig, build_file_path: Path): +def build_image(build_config: BuildConfig, build_file_path: Path): package_deps = Dependencies( docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim", pip_packages=SERVER_DEPENDENCIES, @@ -49,21 +43,32 @@ def build_package(build_config: BuildConfig, build_file_path: Path): # extend package dependencies based on providers spec all_providers = api_providers() - for api_str, provider in build_config.distribution_spec.providers.items(): + for ( + api_str, + provider_or_providers, + ) in build_config.distribution_spec.providers.items(): providers_for_api = all_providers[Api(api_str)] - if provider not in providers_for_api: - raise ValueError( - f"Provider `{provider}` is not available for API `{api_str}`" - ) - provider_spec = providers_for_api[provider] - package_deps.pip_packages.extend(provider_spec.pip_packages) - if provider_spec.docker_image: - raise ValueError("A stack's dependencies cannot have a docker image") + providers = ( + provider_or_providers + if isinstance(provider_or_providers, list) + else [provider_or_providers] + ) + + for provider in providers: + if provider not in providers_for_api: + raise ValueError( + f"Provider `{provider}` is not available for API `{api_str}`" + ) + + provider_spec = providers_for_api[provider] + package_deps.pip_packages.extend(provider_spec.pip_packages) + if provider_spec.docker_image: + raise ValueError("A stack's dependencies cannot have a docker image") if build_config.image_type == ImageType.docker.value: script = pkg_resources.resource_filename( - "llama_toolchain", "core/build_container.sh" + "llama_stack", "distribution/build_container.sh" ) args = [ script, @@ -74,7 +79,7 @@ def build_package(build_config: BuildConfig, build_file_path: Path): ] else: script = pkg_resources.resource_filename( - "llama_toolchain", "core/build_conda_env.sh" + "llama_stack", "distribution/build_conda_env.sh" ) args = [ script, diff --git a/llama_toolchain/core/build_conda_env.sh b/llama_stack/distribution/build_conda_env.sh similarity index 81% rename from llama_toolchain/core/build_conda_env.sh rename to llama_stack/distribution/build_conda_env.sh index 0d0ac82fc..b210a8c8b 100755 --- a/llama_toolchain/core/build_conda_env.sh +++ b/llama_stack/distribution/build_conda_env.sh @@ -7,11 +7,11 @@ # the root directory of this source tree. LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-} -LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-} +LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} -if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then - echo "Using llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR" +if [ -n "$LLAMA_STACK_DIR" ]; then + echo "Using llama-stack-dir=$LLAMA_STACK_DIR" fi if [ -n "$LLAMA_MODELS_DIR" ]; then echo "Using llama-models-dir=$LLAMA_MODELS_DIR" @@ -78,19 +78,19 @@ ensure_conda_env_python310() { if [ -n "$TEST_PYPI_VERSION" ]; then # these packages are damaged in test-pypi, so install them first pip install fastapi libcst - pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-toolchain==$TEST_PYPI_VERSION $pip_dependencies + pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION $pip_dependencies else - # Re-installing llama-toolchain in the new conda environment - if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then - if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then - printf "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}\n" >&2 + # Re-installing llama-stack in the new conda environment + if [ -n "$LLAMA_STACK_DIR" ]; then + if [ ! -d "$LLAMA_STACK_DIR" ]; then + printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2 exit 1 fi - printf "Installing from LLAMA_TOOLCHAIN_DIR: $LLAMA_TOOLCHAIN_DIR\n" - pip install --no-cache-dir -e "$LLAMA_TOOLCHAIN_DIR" + printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n" + pip install --no-cache-dir -e "$LLAMA_STACK_DIR" else - pip install --no-cache-dir llama-toolchain + pip install --no-cache-dir llama-stack fi if [ -n "$LLAMA_MODELS_DIR" ]; then diff --git a/llama_toolchain/core/build_container.sh b/llama_stack/distribution/build_container.sh similarity index 80% rename from llama_toolchain/core/build_container.sh rename to llama_stack/distribution/build_container.sh index d829e8399..836f9bf19 100755 --- a/llama_toolchain/core/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -1,7 +1,7 @@ #!/bin/bash LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-} -LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-} +LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} if [ "$#" -ne 4 ]; then @@ -55,17 +55,17 @@ RUN apt-get update && apt-get install -y \ EOF -toolchain_mount="/app/llama-toolchain-source" +stack_mount="/app/llama-stack-source" models_mount="/app/llama-models-source" -if [ -n "$LLAMA_TOOLCHAIN_DIR" ]; then - if [ ! -d "$LLAMA_TOOLCHAIN_DIR" ]; then - echo "${RED}Warning: LLAMA_TOOLCHAIN_DIR is set but directory does not exist: $LLAMA_TOOLCHAIN_DIR${NC}" >&2 +if [ -n "$LLAMA_STACK_DIR" ]; then + if [ ! -d "$LLAMA_STACK_DIR" ]; then + echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2 exit 1 fi - add_to_docker "RUN pip install $toolchain_mount" + add_to_docker "RUN pip install $stack_mount" else - add_to_docker "RUN pip install llama-toolchain" + add_to_docker "RUN pip install llama-stack" fi if [ -n "$LLAMA_MODELS_DIR" ]; then @@ -90,7 +90,7 @@ add_to_docker < StackRunConfig: + cprint("Configuring APIs to serve...", "white", attrs=["bold"]) + print("Enter comma-separated list of APIs to serve:") + + apis = config.apis_to_serve or list(spec.providers.keys()) + apis = [a for a in apis if a != "telemetry"] + req_apis = ReqApis( + apis_to_serve=apis, + ) + req_apis = prompt_for_config(ReqApis, req_apis) + config.apis_to_serve = req_apis.apis_to_serve + print("") + + apis = [v.value for v in stack_apis()] + all_providers = api_providers() + + for api_str in spec.providers.keys(): + if api_str not in apis: + raise ValueError(f"Unknown API `{api_str}`") + + cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"]) + api = Api(api_str) + + provider_or_providers = spec.providers[api_str] + if isinstance(provider_or_providers, list) and len(provider_or_providers) > 1: + print( + "You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n" + ) + + routing_entries = [] + for p in provider_or_providers: + print(f"Configuring provider `{p}`...") + provider_spec = all_providers[api][p] + config_type = instantiate_class_type(provider_spec.config_class) + + # TODO: we need to validate the routing keys, and + # perhaps it is better if we break this out into asking + # for a routing key separately from the associated config + wrapper_type = make_routing_entry_type(config_type) + rt_entry = prompt_for_config(wrapper_type, None) + + routing_entries.append( + ProviderRoutingEntry( + provider_id=p, + routing_key=rt_entry.routing_key, + config=rt_entry.config.dict(), + ) + ) + config.provider_map[api_str] = routing_entries + else: + p = ( + provider_or_providers[0] + if isinstance(provider_or_providers, list) + else provider_or_providers + ) + print(f"Configuring provider `{p}`...") + provider_spec = all_providers[api][p] + config_type = instantiate_class_type(provider_spec.config_class) + try: + provider_config = config.provider_map.get(api_str) + if provider_config: + existing = config_type(**provider_config.config) + else: + existing = None + except Exception: + existing = None + cfg = prompt_for_config(config_type, existing) + config.provider_map[api_str] = GenericProviderConfig( + provider_id=p, + config=cfg.dict(), + ) + + return config diff --git a/llama_toolchain/core/configure_container.sh b/llama_stack/distribution/configure_container.sh similarity index 100% rename from llama_toolchain/core/configure_container.sh rename to llama_stack/distribution/configure_container.sh diff --git a/llama_stack/distribution/control_plane/__init__.py b/llama_stack/distribution/control_plane/__init__.py new file mode 100644 index 000000000..5abb4e730 --- /dev/null +++ b/llama_stack/distribution/control_plane/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .control_plane import * # noqa: F401 F403 diff --git a/llama_toolchain/inference/__init__.py b/llama_stack/distribution/control_plane/adapters/__init__.py similarity index 100% rename from llama_toolchain/inference/__init__.py rename to llama_stack/distribution/control_plane/adapters/__init__.py diff --git a/llama_stack/distribution/control_plane/adapters/redis/__init__.py b/llama_stack/distribution/control_plane/adapters/redis/__init__.py new file mode 100644 index 000000000..0482718cc --- /dev/null +++ b/llama_stack/distribution/control_plane/adapters/redis/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import RedisImplConfig + + +async def get_adapter_impl(config: RedisImplConfig, _deps): + from .redis import RedisControlPlaneAdapter + + impl = RedisControlPlaneAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/distribution/control_plane/adapters/redis/config.py b/llama_stack/distribution/control_plane/adapters/redis/config.py new file mode 100644 index 000000000..6238611e0 --- /dev/null +++ b/llama_stack/distribution/control_plane/adapters/redis/config.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class RedisImplConfig(BaseModel): + url: str = Field( + description="The URL for the Redis server", + ) + namespace: Optional[str] = Field( + default=None, + description="All keys will be prefixed with this namespace", + ) diff --git a/llama_stack/distribution/control_plane/adapters/redis/redis.py b/llama_stack/distribution/control_plane/adapters/redis/redis.py new file mode 100644 index 000000000..d5c468b77 --- /dev/null +++ b/llama_stack/distribution/control_plane/adapters/redis/redis.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import datetime, timedelta +from typing import Any, List, Optional + +from redis.asyncio import Redis + +from llama_stack.apis.control_plane import * # noqa: F403 + + +from .config import RedisImplConfig + + +class RedisControlPlaneAdapter(ControlPlane): + def __init__(self, config: RedisImplConfig): + self.config = config + + async def initialize(self) -> None: + self.redis = Redis.from_url(self.config.url) + + def _namespaced_key(self, key: str) -> str: + if not self.config.namespace: + return key + return f"{self.config.namespace}:{key}" + + async def set( + self, key: str, value: Any, expiration: Optional[datetime] = None + ) -> None: + key = self._namespaced_key(key) + await self.redis.set(key, value) + if expiration: + await self.redis.expireat(key, expiration) + + async def get(self, key: str) -> Optional[ControlPlaneValue]: + key = self._namespaced_key(key) + value = await self.redis.get(key) + if value is None: + return None + ttl = await self.redis.ttl(key) + expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None + return ControlPlaneValue(key=key, value=value, expiration=expiration) + + async def delete(self, key: str) -> None: + key = self._namespaced_key(key) + await self.redis.delete(key) + + async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: + start_key = self._namespaced_key(start_key) + end_key = self._namespaced_key(end_key) + + keys = await self.redis.keys(f"{start_key}*") + result = [] + for key in keys: + if key <= end_key: + value = await self.get(key) + if value: + result.append(value) + return result diff --git a/llama_stack/distribution/control_plane/adapters/sqlite/__init__.py b/llama_stack/distribution/control_plane/adapters/sqlite/__init__.py new file mode 100644 index 000000000..330f15942 --- /dev/null +++ b/llama_stack/distribution/control_plane/adapters/sqlite/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import SqliteControlPlaneConfig + + +async def get_provider_impl(config: SqliteControlPlaneConfig, _deps): + from .control_plane import SqliteControlPlane + + impl = SqliteControlPlane(config) + await impl.initialize() + return impl diff --git a/llama_stack/distribution/control_plane/adapters/sqlite/config.py b/llama_stack/distribution/control_plane/adapters/sqlite/config.py new file mode 100644 index 000000000..a616c90d0 --- /dev/null +++ b/llama_stack/distribution/control_plane/adapters/sqlite/config.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class SqliteControlPlaneConfig(BaseModel): + db_path: str = Field( + description="File path for the sqlite database", + ) + table_name: str = Field( + default="llamastack_control_plane", + description="Table into which all the keys will be placed", + ) diff --git a/llama_stack/distribution/control_plane/adapters/sqlite/control_plane.py b/llama_stack/distribution/control_plane/adapters/sqlite/control_plane.py new file mode 100644 index 000000000..e2e655244 --- /dev/null +++ b/llama_stack/distribution/control_plane/adapters/sqlite/control_plane.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from datetime import datetime +from typing import Any, List, Optional + +import aiosqlite + +from llama_stack.apis.control_plane import * # noqa: F403 + + +from .config import SqliteControlPlaneConfig + + +class SqliteControlPlane(ControlPlane): + def __init__(self, config: SqliteControlPlaneConfig): + self.db_path = config.db_path + self.table_name = config.table_name + + async def initialize(self): + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + key TEXT PRIMARY KEY, + value TEXT, + expiration TIMESTAMP + ) + """ + ) + await db.commit() + + async def set( + self, key: str, value: Any, expiration: Optional[datetime] = None + ) -> None: + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", + (key, json.dumps(value), expiration), + ) + await db.commit() + + async def get(self, key: str) -> Optional[ControlPlaneValue]: + async with aiosqlite.connect(self.db_path) as db: + async with db.execute( + f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) + ) as cursor: + row = await cursor.fetchone() + if row is None: + return None + value, expiration = row + return ControlPlaneValue( + key=key, value=json.loads(value), expiration=expiration + ) + + async def delete(self, key: str) -> None: + async with aiosqlite.connect(self.db_path) as db: + await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) + await db.commit() + + async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: + async with aiosqlite.connect(self.db_path) as db: + async with db.execute( + f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) as cursor: + result = [] + async for row in cursor: + key, value, expiration = row + result.append( + ControlPlaneValue( + key=key, value=json.loads(value), expiration=expiration + ) + ) + return result diff --git a/llama_stack/distribution/control_plane/api.py b/llama_stack/distribution/control_plane/api.py new file mode 100644 index 000000000..db79e91cd --- /dev/null +++ b/llama_stack/distribution/control_plane/api.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import datetime +from typing import Any, List, Optional, Protocol + +from llama_models.schema_utils import json_schema_type, webmethod +from pydantic import BaseModel + + +@json_schema_type +class ControlPlaneValue(BaseModel): + key: str + value: Any + expiration: Optional[datetime] = None + + +@json_schema_type +class ControlPlane(Protocol): + @webmethod(route="/control_plane/set") + async def set( + self, key: str, value: Any, expiration: Optional[datetime] = None + ) -> None: ... + + @webmethod(route="/control_plane/get", method="GET") + async def get(self, key: str) -> Optional[ControlPlaneValue]: ... + + @webmethod(route="/control_plane/delete") + async def delete(self, key: str) -> None: ... + + @webmethod(route="/control_plane/range", method="GET") + async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: ... diff --git a/llama_stack/distribution/control_plane/registry.py b/llama_stack/distribution/control_plane/registry.py new file mode 100644 index 000000000..7465c4534 --- /dev/null +++ b/llama_stack/distribution/control_plane/registry.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import List + +from llama_stack.distribution.datatypes import * # noqa: F403 + + +def available_providers() -> List[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.control_plane, + provider_id="sqlite", + pip_packages=["aiosqlite"], + module="llama_stack.providers.impls.sqlite.control_plane", + config_class="llama_stack.providers.impls.sqlite.control_plane.SqliteControlPlaneConfig", + ), + remote_provider_spec( + Api.control_plane, + AdapterSpec( + adapter_id="redis", + pip_packages=["redis"], + module="llama_stack.providers.adapters.control_plane.redis", + ), + ), + ] diff --git a/llama_toolchain/core/datatypes.py b/llama_stack/distribution/datatypes.py similarity index 72% rename from llama_toolchain/core/datatypes.py rename to llama_stack/distribution/datatypes.py index f523e0308..e57617016 100644 --- a/llama_toolchain/core/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -6,7 +6,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from llama_models.schema_utils import json_schema_type @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field, validator class Api(Enum): inference = "inference" safety = "safety" - agentic_system = "agentic_system" + agents = "agents" memory = "memory" telemetry = "telemetry" @@ -43,6 +43,33 @@ class ProviderSpec(BaseModel): ) +@json_schema_type +class RouterProviderSpec(ProviderSpec): + provider_id: str = "router" + config_class: str = "" + + docker_image: Optional[str] = None + + inner_specs: List[ProviderSpec] + module: str = Field( + ..., + description=""" +Fully-qualified name of the module to import. The module is expected to have: + + - `get_router_impl(config, provider_specs, deps)`: returns the router implementation +""", + ) + + @property + def pip_packages(self) -> List[str]: + raise AssertionError("Should not be called on RouterProviderSpec") + + +class GenericProviderConfig(BaseModel): + provider_id: str + config: Dict[str, Any] + + @json_schema_type class AdapterSpec(BaseModel): adapter_id: str = Field( @@ -124,7 +151,7 @@ as being "Llama Stack compatible" def module(self) -> str: if self.adapter: return self.adapter.module - return f"llama_toolchain.{self.api.value}.client" + return f"llama_stack.apis.{self.api.value}.client" @property def pip_packages(self) -> List[str]: @@ -140,7 +167,7 @@ def remote_provider_spec( config_class = ( adapter.config_class if adapter and adapter.config_class - else "llama_toolchain.core.datatypes.RemoteProviderConfig" + else "llama_stack.distribution.datatypes.RemoteProviderConfig" ) provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote" @@ -156,12 +183,23 @@ class DistributionSpec(BaseModel): description="Description of the distribution", ) docker_image: Optional[str] = None - providers: Dict[str, str] = Field( + providers: Dict[str, Union[str, List[str]]] = Field( default_factory=dict, - description="Provider Types for each of the APIs provided by this distribution", + description=""" +Provider Types for each of the APIs provided by this distribution. If you +select multiple providers, you should provide an appropriate 'routing_map' +in the runtime configuration to help route to the correct provider.""", ) +@json_schema_type +class ProviderRoutingEntry(GenericProviderConfig): + routing_key: str + + +ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]] + + @json_schema_type class StackRunConfig(BaseModel): built_at: datetime @@ -181,12 +219,22 @@ this could be just a hash default=None, description="Reference to the conda environment if this package refers to a conda environment", ) - providers: Dict[str, Any] = Field( - default_factory=dict, + apis_to_serve: List[str] = Field( description=""" -Provider configurations for each of the APIs provided by this package. This includes configurations for -the dependencies of these providers as well. -""", +The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""", + ) + provider_map: Dict[str, ProviderMapEntry] = Field( + description=""" +Provider configurations for each of the APIs provided by this package. + +Given an API, you can specify a single provider or a "routing table". Each entry in the routing +table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific. + +As examples: +- the "inference" API interprets the routing_key as a "model" +- the "memory" API interprets the routing_key as the type of a "memory bank" + +The key may support wild-cards alsothe routing_key to route to the correct provider.""", ) diff --git a/llama_toolchain/core/distribution.py b/llama_stack/distribution/distribution.py similarity index 79% rename from llama_toolchain/core/distribution.py rename to llama_stack/distribution/distribution.py index dc81b53f1..0825121dc 100644 --- a/llama_toolchain/core/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -8,18 +8,19 @@ import importlib import inspect from typing import Dict, List -from llama_toolchain.agentic_system.api import AgenticSystem -from llama_toolchain.inference.api import Inference -from llama_toolchain.memory.api import Memory -from llama_toolchain.safety.api import Safety -from llama_toolchain.telemetry.api import Telemetry +from llama_stack.apis.agents import Agents +from llama_stack.apis.inference import Inference +from llama_stack.apis.memory import Memory +from llama_stack.apis.safety import Safety +from llama_stack.apis.telemetry import Telemetry from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec # These are the dependencies needed by the distribution server. -# `llama-toolchain` is automatically installed by the installation script. +# `llama-stack` is automatically installed by the installation script. SERVER_DEPENDENCIES = [ "fastapi", + "fire", "uvicorn", ] @@ -34,7 +35,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]: protocols = { Api.inference: Inference, Api.safety: Safety, - Api.agentic_system: AgenticSystem, + Api.agents: Agents, Api.memory: Memory, Api.telemetry: Telemetry, } @@ -67,7 +68,7 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]: ret = {} for api in stack_apis(): name = api.name.lower() - module = importlib.import_module(f"llama_toolchain.{name}.providers") + module = importlib.import_module(f"llama_stack.providers.registry.{name}") ret[api] = { "remote": remote_provider_spec(api), **{a.provider_id: a for a in module.available_providers()}, diff --git a/llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml b/llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml new file mode 100644 index 000000000..f98e48570 --- /dev/null +++ b/llama_stack/distribution/example_configs/conda/local-conda-example-build.yaml @@ -0,0 +1,10 @@ +name: local-conda-example +distribution_spec: + description: Use code from `llama_stack` itself to serve all llama stack APIs + providers: + inference: meta-reference + memory: meta-reference + safety: meta-reference + agents: meta-reference + telemetry: meta-reference +image_type: conda diff --git a/llama_toolchain/configs/distributions/conda/local-fireworks-conda-example-build.yaml b/llama_stack/distribution/example_configs/conda/local-fireworks-conda-example-build.yaml similarity index 69% rename from llama_toolchain/configs/distributions/conda/local-fireworks-conda-example-build.yaml rename to llama_stack/distribution/example_configs/conda/local-fireworks-conda-example-build.yaml index c3b38aebe..5e17f83c0 100644 --- a/llama_toolchain/configs/distributions/conda/local-fireworks-conda-example-build.yaml +++ b/llama_stack/distribution/example_configs/conda/local-fireworks-conda-example-build.yaml @@ -3,8 +3,8 @@ distribution_spec: description: Use Fireworks.ai for running LLM inference providers: inference: remote::fireworks - memory: meta-reference-faiss + memory: meta-reference safety: meta-reference - agentic_system: meta-reference - telemetry: console + agents: meta-reference + telemetry: meta-reference image_type: conda diff --git a/llama_toolchain/configs/distributions/conda/local-ollama-conda-example-build.yaml b/llama_stack/distribution/example_configs/conda/local-ollama-conda-example-build.yaml similarity index 69% rename from llama_toolchain/configs/distributions/conda/local-ollama-conda-example-build.yaml rename to llama_stack/distribution/example_configs/conda/local-ollama-conda-example-build.yaml index 31bc9d0e9..1c43e5998 100644 --- a/llama_toolchain/configs/distributions/conda/local-ollama-conda-example-build.yaml +++ b/llama_stack/distribution/example_configs/conda/local-ollama-conda-example-build.yaml @@ -3,8 +3,8 @@ distribution_spec: description: Like local, but use ollama for running LLM inference providers: inference: remote::ollama - memory: meta-reference-faiss + memory: meta-reference safety: meta-reference - agentic_system: meta-reference - telemetry: console + agents: meta-reference + telemetry: meta-reference image_type: conda diff --git a/llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml b/llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml similarity index 77% rename from llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml rename to llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml index 1ac6f44ba..07848b130 100644 --- a/llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml +++ b/llama_stack/distribution/example_configs/conda/local-tgi-conda-example-build.yaml @@ -3,8 +3,8 @@ distribution_spec: description: Use TGI (local or with Hugging Face Inference Endpoints for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint). providers: inference: remote::tgi - memory: meta-reference-faiss + memory: meta-reference safety: meta-reference - agentic_system: meta-reference - telemetry: console + agents: meta-reference + telemetry: meta-reference image_type: conda diff --git a/llama_toolchain/configs/distributions/conda/local-together-conda-example-build.yaml b/llama_stack/distribution/example_configs/conda/local-together-conda-example-build.yaml similarity index 68% rename from llama_toolchain/configs/distributions/conda/local-together-conda-example-build.yaml rename to llama_stack/distribution/example_configs/conda/local-together-conda-example-build.yaml index 4aa13fed5..df11bd0ed 100644 --- a/llama_toolchain/configs/distributions/conda/local-together-conda-example-build.yaml +++ b/llama_stack/distribution/example_configs/conda/local-together-conda-example-build.yaml @@ -3,8 +3,8 @@ distribution_spec: description: Use Together.ai for running LLM inference providers: inference: remote::together - memory: meta-reference-faiss + memory: meta-reference safety: meta-reference - agentic_system: meta-reference - telemetry: console + agents: meta-reference + telemetry: meta-reference image_type: conda diff --git a/llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml b/llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml new file mode 100644 index 000000000..885b6e58c --- /dev/null +++ b/llama_stack/distribution/example_configs/docker/local-docker-example-build.yaml @@ -0,0 +1,10 @@ +name: local-docker-example +distribution_spec: + description: Use code from `llama_stack` itself to serve all llama stack APIs + providers: + inference: meta-reference + memory: meta-reference + safety: meta-reference + agents: meta-reference + telemetry: meta-reference +image_type: docker diff --git a/llama_toolchain/inference/adapters/__init__.py b/llama_stack/distribution/server/__init__.py similarity index 100% rename from llama_toolchain/inference/adapters/__init__.py rename to llama_stack/distribution/server/__init__.py diff --git a/llama_toolchain/core/server.py b/llama_stack/distribution/server/server.py similarity index 82% rename from llama_toolchain/core/server.py rename to llama_stack/distribution/server/server.py index 7082ec765..16d24cad5 100644 --- a/llama_toolchain/core/server.py +++ b/llama_stack/distribution/server/server.py @@ -9,6 +9,7 @@ import inspect import json import signal import traceback + from collections.abc import ( AsyncGenerator as AsyncGeneratorABC, AsyncIterator as AsyncIteratorABC, @@ -38,16 +39,16 @@ from pydantic import BaseModel, ValidationError from termcolor import cprint from typing_extensions import Annotated -from llama_toolchain.telemetry.tracing import ( +from llama_stack.providers.utils.telemetry.tracing import ( end_trace, setup_logger, SpanStatus, start_trace, ) +from llama_stack.distribution.datatypes import * # noqa: F403 -from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec -from .distribution import api_endpoints, api_providers -from .dynamic import instantiate_provider +from llama_stack.distribution.distribution import api_endpoints, api_providers +from llama_stack.distribution.utils.dynamic import instantiate_provider def is_async_iterator_type(typ): @@ -271,61 +272,80 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]: return [by_id[x] for x in stack] -def resolve_impls( - provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any] -) -> Dict[Api, Any]: - provider_configs = config["providers"] - provider_specs = topological_sort(provider_specs.values()) +def snake_to_camel(snake_str): + return "".join(word.capitalize() for word in snake_str.split("_")) - impls = {} - for provider_spec in provider_specs: - api = provider_spec.api - if api.value not in provider_configs: - raise ValueError( - f"Could not find provider_spec config for {api}. Please add it to the config" + +async def resolve_impls( + provider_map: Dict[str, ProviderMapEntry], +) -> Dict[Api, Any]: + """ + Does two things: + - flatmaps, sorts and resolves the providers in dependency order + - for each API, produces either a (local, passthrough or router) implementation + """ + all_providers = api_providers() + + specs = {} + for api_str, item in provider_map.items(): + api = Api(api_str) + providers = all_providers[api] + + if isinstance(item, GenericProviderConfig): + if item.provider_id not in providers: + raise ValueError( + f"Unknown provider `{provider_id}` is not available for API `{api}`" + ) + specs[api] = providers[item.provider_id] + else: + assert isinstance(item, list) + inner_specs = [] + for rt_entry in item: + if rt_entry.provider_id not in providers: + raise ValueError( + f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" + ) + inner_specs.append(providers[rt_entry.provider_id]) + + specs[api] = RouterProviderSpec( + api=api, + module=f"llama_stack.providers.routers.{api.value.lower()}", + api_dependencies=[], + inner_specs=inner_specs, ) - if isinstance(provider_spec, InlineProviderSpec): - deps = {api: impls[api] for api in provider_spec.api_dependencies} - else: - deps = {} - provider_config = provider_configs[api.value] - impl = instantiate_provider(provider_spec, provider_config, deps) + sorted_specs = topological_sort(specs.values()) + + impls = {} + for spec in sorted_specs: + api = spec.api + + deps = {api: impls[api] for api in spec.api_dependencies} + impl = await instantiate_provider(spec, deps, provider_map[api.value]) impls[api] = impl - return impls + return impls, specs def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): with open(yaml_config, "r") as fp: - config = yaml.safe_load(fp) + config = StackRunConfig(**yaml.safe_load(fp)) app = FastAPI() - all_endpoints = api_endpoints() - all_providers = api_providers() - - provider_specs = {} - for api_str, provider_config in config["providers"].items(): - api = Api(api_str) - providers = all_providers[api] - provider_id = provider_config["provider_id"] - if provider_id not in providers: - raise ValueError( - f"Unknown provider `{provider_id}` is not available for API `{api}`" - ) - - provider_specs[api] = providers[provider_id] - - impls = resolve_impls(provider_specs, config) + impls, specs = asyncio.run(resolve_impls(config.provider_map)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) - for provider_spec in provider_specs.values(): - api = provider_spec.api + all_endpoints = api_endpoints() + + apis_to_serve = config.apis_to_serve or list(config.provider_map.keys()) + for api_str in apis_to_serve: + api = Api(api_str) endpoints = all_endpoints[api] impl = impls[api] + provider_spec = specs[api] if ( isinstance(provider_spec, RemoteProviderSpec) and provider_spec.adapter is None diff --git a/llama_toolchain/core/start_conda_env.sh b/llama_stack/distribution/start_conda_env.sh similarity index 94% rename from llama_toolchain/core/start_conda_env.sh rename to llama_stack/distribution/start_conda_env.sh index 120dda006..3d91564b8 100755 --- a/llama_toolchain/core/start_conda_env.sh +++ b/llama_stack/distribution/start_conda_env.sh @@ -37,6 +37,6 @@ eval "$(conda shell.bash hook)" conda deactivate && conda activate "$env_name" $CONDA_PREFIX/bin/python \ - -m llama_toolchain.core.server \ + -m llama_stack.distribution.server.server \ --yaml_config "$yaml_config" \ --port "$port" "$@" diff --git a/llama_toolchain/core/start_container.sh b/llama_stack/distribution/start_container.sh similarity index 93% rename from llama_toolchain/core/start_container.sh rename to llama_stack/distribution/start_container.sh index 676bcedcf..ac5fbb565 100755 --- a/llama_toolchain/core/start_container.sh +++ b/llama_stack/distribution/start_container.sh @@ -38,6 +38,6 @@ podman run -it \ -p $port:$port \ -v "$yaml_config:/app/config.yaml" \ $docker_image \ - python -m llama_toolchain.core.server \ + python -m llama_stack.distribution.server.server \ --yaml_config /app/config.yaml \ --port $port "$@" diff --git a/llama_toolchain/memory/__init__.py b/llama_stack/distribution/utils/__init__.py similarity index 100% rename from llama_toolchain/memory/__init__.py rename to llama_stack/distribution/utils/__init__.py diff --git a/llama_toolchain/common/config_dirs.py b/llama_stack/distribution/utils/config_dirs.py similarity index 100% rename from llama_toolchain/common/config_dirs.py rename to llama_stack/distribution/utils/config_dirs.py diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py new file mode 100644 index 000000000..002a738ae --- /dev/null +++ b/llama_stack/distribution/utils/dynamic.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import importlib +from typing import Any, Dict + +from llama_stack.distribution.datatypes import * # noqa: F403 + + +def instantiate_class_type(fully_qualified_name): + module_name, class_name = fully_qualified_name.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + + +# returns a class implementing the protocol corresponding to the Api +async def instantiate_provider( + provider_spec: ProviderSpec, + deps: Dict[str, Any], + provider_config: ProviderMapEntry, +): + module = importlib.import_module(provider_spec.module) + + args = [] + if isinstance(provider_spec, RemoteProviderSpec): + if provider_spec.adapter: + method = "get_adapter_impl" + else: + method = "get_client_impl" + + assert isinstance(provider_config, GenericProviderConfig) + config_type = instantiate_class_type(provider_spec.config_class) + config = config_type(**provider_config.config) + args = [config, deps] + elif isinstance(provider_spec, RouterProviderSpec): + method = "get_router_impl" + + assert isinstance(provider_config, list) + inner_specs = {x.provider_id: x for x in provider_spec.inner_specs} + inner_impls = [] + for routing_entry in provider_config: + impl = await instantiate_provider( + inner_specs[routing_entry.provider_id], + deps, + routing_entry, + ) + inner_impls.append((routing_entry.routing_key, impl)) + + config = None + args = [inner_impls, deps] + else: + method = "get_provider_impl" + + assert isinstance(provider_config, GenericProviderConfig) + config_type = instantiate_class_type(provider_spec.config_class) + config = config_type(**provider_config.config) + args = [config, deps] + + fn = getattr(module, method) + impl = await fn(*args) + impl.__provider_spec__ = provider_spec + impl.__provider_config__ = config + return impl diff --git a/llama_toolchain/common/exec.py b/llama_stack/distribution/utils/exec.py similarity index 100% rename from llama_toolchain/common/exec.py rename to llama_stack/distribution/utils/exec.py diff --git a/llama_toolchain/common/model_utils.py b/llama_stack/distribution/utils/model_utils.py similarity index 100% rename from llama_toolchain/common/model_utils.py rename to llama_stack/distribution/utils/model_utils.py diff --git a/llama_toolchain/common/prompt_for_config.py b/llama_stack/distribution/utils/prompt_for_config.py similarity index 91% rename from llama_toolchain/common/prompt_for_config.py rename to llama_stack/distribution/utils/prompt_for_config.py index 4f92ec7d9..d9d778540 100644 --- a/llama_toolchain/common/prompt_for_config.py +++ b/llama_stack/distribution/utils/prompt_for_config.py @@ -27,6 +27,12 @@ def is_list_of_primitives(field_type): return False +def is_basemodel_without_fields(typ): + return ( + inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0 + ) + + def can_recurse(typ): return ( inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0 @@ -151,6 +157,11 @@ def prompt_for_config( if get_origin(field_type) is Literal: continue + # Skip fields with no type annotations + if is_basemodel_without_fields(field_type): + config_data[field_name] = field_type() + continue + if inspect.isclass(field_type) and issubclass(field_type, Enum): prompt = f"Choose {field_name} (options: {', '.join(e.name for e in field_type)}):" while True: @@ -254,6 +265,20 @@ def prompt_for_config( print(f"{str(e)}") continue + elif get_origin(field_type) is dict: + try: + value = json.loads(user_input) + if not isinstance(value, dict): + raise ValueError( + "Input must be a JSON-encoded dictionary" + ) + + except json.JSONDecodeError: + print( + "Invalid JSON. Please enter a valid JSON-encoded dict." + ) + continue + # Convert the input to the correct type elif inspect.isclass(field_type) and issubclass( field_type, BaseModel diff --git a/llama_toolchain/common/serialize.py b/llama_stack/distribution/utils/serialize.py similarity index 100% rename from llama_toolchain/common/serialize.py rename to llama_stack/distribution/utils/serialize.py diff --git a/llama_toolchain/memory/common/__init__.py b/llama_stack/providers/__init__.py similarity index 100% rename from llama_toolchain/memory/common/__init__.py rename to llama_stack/providers/__init__.py diff --git a/llama_toolchain/memory/meta_reference/__init__.py b/llama_stack/providers/adapters/__init__.py similarity index 100% rename from llama_toolchain/memory/meta_reference/__init__.py rename to llama_stack/providers/adapters/__init__.py diff --git a/llama_toolchain/safety/__init__.py b/llama_stack/providers/adapters/inference/__init__.py similarity index 100% rename from llama_toolchain/safety/__init__.py rename to llama_stack/providers/adapters/inference/__init__.py diff --git a/llama_toolchain/inference/adapters/fireworks/__init__.py b/llama_stack/providers/adapters/inference/fireworks/__init__.py similarity index 100% rename from llama_toolchain/inference/adapters/fireworks/__init__.py rename to llama_stack/providers/adapters/inference/fireworks/__init__.py diff --git a/llama_toolchain/inference/adapters/fireworks/config.py b/llama_stack/providers/adapters/inference/fireworks/config.py similarity index 100% rename from llama_toolchain/inference/adapters/fireworks/config.py rename to llama_stack/providers/adapters/inference/fireworks/config.py diff --git a/llama_toolchain/inference/adapters/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py similarity index 97% rename from llama_toolchain/inference/adapters/fireworks/fireworks.py rename to llama_stack/providers/adapters/inference/fireworks/fireworks.py index e51a730de..1e6f2e753 100644 --- a/llama_toolchain/inference/adapters/fireworks/fireworks.py +++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py @@ -6,15 +6,16 @@ from typing import AsyncGenerator -from fireworks.client import Fireworks from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message, StopReason from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model -from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.inference.prepare_messages import prepare_messages +from fireworks.client import Fireworks + +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.providers.utils.inference.prepare_messages import prepare_messages from .config import FireworksImplConfig @@ -81,7 +82,7 @@ class FireworksInferenceAdapter(Inference): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), - tools: Optional[List[ToolDefinition]] = list(), + tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, stream: Optional[bool] = False, @@ -91,7 +92,7 @@ class FireworksInferenceAdapter(Inference): model=model, messages=messages, sampling_params=sampling_params, - tools=tools, + tools=tools or [], tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, stream=stream, diff --git a/llama_toolchain/inference/adapters/ollama/__init__.py b/llama_stack/providers/adapters/inference/ollama/__init__.py similarity index 85% rename from llama_toolchain/inference/adapters/ollama/__init__.py rename to llama_stack/providers/adapters/inference/ollama/__init__.py index 8369a00a5..2a1f7d140 100644 --- a/llama_toolchain/inference/adapters/ollama/__init__.py +++ b/llama_stack/providers/adapters/inference/ollama/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_toolchain.core.datatypes import RemoteProviderConfig +from llama_stack.distribution.datatypes import RemoteProviderConfig async def get_adapter_impl(config: RemoteProviderConfig, _deps): diff --git a/llama_toolchain/inference/adapters/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py similarity index 97% rename from llama_toolchain/inference/adapters/ollama/ollama.py rename to llama_stack/providers/adapters/inference/ollama/ollama.py index 92fbf7585..ea726ff75 100644 --- a/llama_toolchain/inference/adapters/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -12,10 +12,11 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message, StopReason from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model + from ollama import AsyncClient -from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.inference.prepare_messages import prepare_messages +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.providers.utils.inference.prepare_messages import prepare_messages # TODO: Eventually this will move to the llama cli model list command # mapping of Model SKUs to ollama models @@ -89,7 +90,7 @@ class OllamaInferenceAdapter(Inference): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), - tools: Optional[List[ToolDefinition]] = list(), + tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, stream: Optional[bool] = False, @@ -99,7 +100,7 @@ class OllamaInferenceAdapter(Inference): model=model, messages=messages, sampling_params=sampling_params, - tools=tools, + tools=tools or [], tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, stream=stream, diff --git a/llama_toolchain/inference/adapters/tgi/__init__.py b/llama_stack/providers/adapters/inference/tgi/__init__.py similarity index 100% rename from llama_toolchain/inference/adapters/tgi/__init__.py rename to llama_stack/providers/adapters/inference/tgi/__init__.py diff --git a/llama_toolchain/inference/adapters/tgi/config.py b/llama_stack/providers/adapters/inference/tgi/config.py similarity index 100% rename from llama_toolchain/inference/adapters/tgi/config.py rename to llama_stack/providers/adapters/inference/tgi/config.py diff --git a/llama_toolchain/inference/adapters/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py similarity index 97% rename from llama_toolchain/inference/adapters/tgi/tgi.py rename to llama_stack/providers/adapters/inference/tgi/tgi.py index 7b1028817..3be1f3e98 100644 --- a/llama_toolchain/inference/adapters/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -13,8 +13,8 @@ from huggingface_hub import HfApi, InferenceClient from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import StopReason from llama_models.llama3.api.tokenizer import Tokenizer -from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.inference.prepare_messages import prepare_messages +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.providers.utils.inference.prepare_messages import prepare_messages from .config import TGIImplConfig @@ -87,7 +87,7 @@ class TGIAdapter(Inference): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), - tools: Optional[List[ToolDefinition]] = list(), + tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, stream: Optional[bool] = False, @@ -97,7 +97,7 @@ class TGIAdapter(Inference): model=model, messages=messages, sampling_params=sampling_params, - tools=tools, + tools=tools or [], tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, stream=stream, diff --git a/llama_toolchain/inference/adapters/together/__init__.py b/llama_stack/providers/adapters/inference/together/__init__.py similarity index 100% rename from llama_toolchain/inference/adapters/together/__init__.py rename to llama_stack/providers/adapters/inference/together/__init__.py diff --git a/llama_toolchain/inference/adapters/together/config.py b/llama_stack/providers/adapters/inference/together/config.py similarity index 100% rename from llama_toolchain/inference/adapters/together/config.py rename to llama_stack/providers/adapters/inference/together/config.py diff --git a/llama_toolchain/inference/adapters/together/together.py b/llama_stack/providers/adapters/inference/together/together.py similarity index 97% rename from llama_toolchain/inference/adapters/together/together.py rename to llama_stack/providers/adapters/inference/together/together.py index 76403a85b..565130883 100644 --- a/llama_toolchain/inference/adapters/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -11,10 +11,11 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message, StopReason from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model + from together import Together -from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.inference.prepare_messages import prepare_messages +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.providers.utils.inference.prepare_messages import prepare_messages from .config import TogetherImplConfig @@ -81,7 +82,7 @@ class TogetherInferenceAdapter(Inference): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), - tools: Optional[List[ToolDefinition]] = list(), + tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, stream: Optional[bool] = False, @@ -92,7 +93,7 @@ class TogetherInferenceAdapter(Inference): model=model, messages=messages, sampling_params=sampling_params, - tools=tools, + tools=tools or [], tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, stream=stream, diff --git a/llama_toolchain/safety/meta_reference/shields/contrib/__init__.py b/llama_stack/providers/adapters/memory/__init__.py similarity index 100% rename from llama_toolchain/safety/meta_reference/shields/contrib/__init__.py rename to llama_stack/providers/adapters/memory/__init__.py diff --git a/llama_toolchain/memory/adapters/chroma/__init__.py b/llama_stack/providers/adapters/memory/chroma/__init__.py similarity index 85% rename from llama_toolchain/memory/adapters/chroma/__init__.py rename to llama_stack/providers/adapters/memory/chroma/__init__.py index c90a8e8ac..dfd5c5696 100644 --- a/llama_toolchain/memory/adapters/chroma/__init__.py +++ b/llama_stack/providers/adapters/memory/chroma/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_toolchain.core.datatypes import RemoteProviderConfig +from llama_stack.distribution.datatypes import RemoteProviderConfig async def get_adapter_impl(config: RemoteProviderConfig, _deps): diff --git a/llama_toolchain/memory/adapters/chroma/chroma.py b/llama_stack/providers/adapters/memory/chroma/chroma.py similarity index 97% rename from llama_toolchain/memory/adapters/chroma/chroma.py rename to llama_stack/providers/adapters/memory/chroma/chroma.py index f4952cd0e..15f5810a9 100644 --- a/llama_toolchain/memory/adapters/chroma/chroma.py +++ b/llama_stack/providers/adapters/memory/chroma/chroma.py @@ -12,10 +12,13 @@ from urllib.parse import urlparse import chromadb from numpy.typing import NDArray -from llama_toolchain.memory.api import * # noqa: F403 +from llama_stack.apis.memory import * # noqa: F403 -from llama_toolchain.memory.common.vector_store import BankWithIndex, EmbeddingIndex +from llama_stack.providers.utils.memory.vector_store import ( + BankWithIndex, + EmbeddingIndex, +) class ChromaIndex(EmbeddingIndex): diff --git a/llama_toolchain/memory/adapters/pgvector/__init__.py b/llama_stack/providers/adapters/memory/pgvector/__init__.py similarity index 100% rename from llama_toolchain/memory/adapters/pgvector/__init__.py rename to llama_stack/providers/adapters/memory/pgvector/__init__.py diff --git a/llama_toolchain/memory/adapters/pgvector/config.py b/llama_stack/providers/adapters/memory/pgvector/config.py similarity index 100% rename from llama_toolchain/memory/adapters/pgvector/config.py rename to llama_stack/providers/adapters/memory/pgvector/config.py diff --git a/llama_toolchain/memory/adapters/pgvector/pgvector.py b/llama_stack/providers/adapters/memory/pgvector/pgvector.py similarity index 98% rename from llama_toolchain/memory/adapters/pgvector/pgvector.py rename to llama_stack/providers/adapters/memory/pgvector/pgvector.py index 930d7720f..a5c84a1b2 100644 --- a/llama_toolchain/memory/adapters/pgvector/pgvector.py +++ b/llama_stack/providers/adapters/memory/pgvector/pgvector.py @@ -13,10 +13,10 @@ from numpy.typing import NDArray from psycopg2 import sql from psycopg2.extras import execute_values, Json from pydantic import BaseModel -from llama_toolchain.memory.api import * # noqa: F403 +from llama_stack.apis.memory import * # noqa: F403 -from llama_toolchain.memory.common.vector_store import ( +from llama_stack.providers.utils.memory.vector_store import ( ALL_MINILM_L6_V2_DIMENSION, BankWithIndex, EmbeddingIndex, diff --git a/llama_toolchain/telemetry/__init__.py b/llama_stack/providers/impls/__init__.py similarity index 100% rename from llama_toolchain/telemetry/__init__.py rename to llama_stack/providers/impls/__init__.py diff --git a/llama_toolchain/tools/__init__.py b/llama_stack/providers/impls/meta_reference/__init__.py similarity index 100% rename from llama_toolchain/tools/__init__.py rename to llama_stack/providers/impls/meta_reference/__init__.py diff --git a/llama_toolchain/agentic_system/meta_reference/__init__.py b/llama_stack/providers/impls/meta_reference/agents/__init__.py similarity index 79% rename from llama_toolchain/agentic_system/meta_reference/__init__.py rename to llama_stack/providers/impls/meta_reference/agents/__init__.py index b49cc4c84..b6f3e6456 100644 --- a/llama_toolchain/agentic_system/meta_reference/__init__.py +++ b/llama_stack/providers/impls/meta_reference/agents/__init__.py @@ -6,7 +6,7 @@ from typing import Dict -from llama_toolchain.core.datatypes import Api, ProviderSpec +from llama_stack.distribution.datatypes import Api, ProviderSpec from .config import MetaReferenceImplConfig @@ -14,13 +14,13 @@ from .config import MetaReferenceImplConfig async def get_provider_impl( config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec] ): - from .agentic_system import MetaReferenceAgenticSystemImpl + from .agents import MetaReferenceAgentsImpl assert isinstance( config, MetaReferenceImplConfig ), f"Unexpected config type: {type(config)}" - impl = MetaReferenceAgenticSystemImpl( + impl = MetaReferenceAgentsImpl( config, deps[Api.inference], deps[Api.memory], diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py similarity index 87% rename from llama_toolchain/agentic_system/meta_reference/agent_instance.py rename to llama_stack/providers/impls/meta_reference/agents/agent_instance.py index 202f42a3c..d7f10a4f5 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py @@ -20,19 +20,15 @@ import httpx from termcolor import cprint -from llama_toolchain.agentic_system.api import * # noqa: F403 -from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.memory.api import * # noqa: F403 -from llama_toolchain.safety.api import * # noqa: F403 - -from llama_toolchain.tools.base import BaseTool -from llama_toolchain.tools.builtin import ( - interpret_content_as_attachment, - SingleMessageBuiltinTool, -) +from llama_stack.apis.agents import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 from .rag.context_retriever import generate_rag_query from .safety import SafetyException, ShieldRunnerMixin +from .tools.base import BaseTool +from .tools.builtin import interpret_content_as_attachment, SingleMessageBuiltinTool def make_random_string(length: int = 8): @@ -122,7 +118,7 @@ class ChatAgent(ShieldRunnerMixin): return session async def create_and_execute_turn( - self, request: AgenticSystemTurnCreateRequest + self, request: AgentTurnCreateRequest ) -> AsyncGenerator: assert ( request.session_id in self.sessions @@ -141,9 +137,9 @@ class ChatAgent(ShieldRunnerMixin): turn_id = str(uuid.uuid4()) start_time = datetime.now() - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseTurnStartPayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnStartPayload( turn_id=turn_id, ) ) @@ -169,12 +165,12 @@ class ChatAgent(ShieldRunnerMixin): continue assert isinstance( - chunk, AgenticSystemTurnResponseStreamChunk + chunk, AgentTurnResponseStreamChunk ), f"Unexpected type {type(chunk)}" event = chunk.event if ( event.payload.event_type - == AgenticSystemTurnResponseEventType.step_complete.value + == AgentTurnResponseEventType.step_complete.value ): steps.append(event.payload.step_details) @@ -193,9 +189,9 @@ class ChatAgent(ShieldRunnerMixin): ) session.turns.append(turn) - chunk = AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseTurnCompletePayload( + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( turn=turn, ) ) @@ -261,9 +257,9 @@ class ChatAgent(ShieldRunnerMixin): step_id = str(uuid.uuid4()) try: - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepStartPayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepStartPayload( step_type=StepType.shield_call.value, step_id=step_id, metadata=dict(touchpoint=touchpoint), @@ -273,9 +269,9 @@ class ChatAgent(ShieldRunnerMixin): await self.run_shields(messages, shields) except SafetyException as e: - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepCompletePayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( step_type=StepType.shield_call.value, step_details=ShieldCallStep( step_id=step_id, @@ -292,9 +288,9 @@ class ChatAgent(ShieldRunnerMixin): ) yield False - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepCompletePayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( step_type=StepType.shield_call.value, step_details=ShieldCallStep( step_id=step_id, @@ -325,9 +321,9 @@ class ChatAgent(ShieldRunnerMixin): ) if need_rag_context: step_id = str(uuid.uuid4()) - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepStartPayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepStartPayload( step_type=StepType.memory_retrieval.value, step_id=step_id, ) @@ -341,9 +337,9 @@ class ChatAgent(ShieldRunnerMixin): ) step_id = str(uuid.uuid4()) - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepCompletePayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( step_type=StepType.memory_retrieval.value, step_id=step_id, step_details=MemoryRetrievalStep( @@ -360,7 +356,7 @@ class ChatAgent(ShieldRunnerMixin): last_message = input_messages[-1] last_message.context = "\n".join(rag_context) - elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools: + elif attachments and AgentTool.code_interpreter.value in enabled_tools: urls = [a.content for a in attachments if isinstance(a.content, URL)] msg = await attachment_message(self.tempdir, urls) input_messages.append(msg) @@ -379,9 +375,9 @@ class ChatAgent(ShieldRunnerMixin): cprint(f"{str(msg)}", color=color) step_id = str(uuid.uuid4()) - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepStartPayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepStartPayload( step_type=StepType.inference.value, step_id=step_id, ) @@ -412,9 +408,9 @@ class ChatAgent(ShieldRunnerMixin): tool_calls.append(delta.content) if stream: - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepProgressPayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( step_type=StepType.inference.value, step_id=step_id, model_response_text_delta="", @@ -426,9 +422,9 @@ class ChatAgent(ShieldRunnerMixin): elif isinstance(delta, str): content += delta if stream and event.stop_reason is None: - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepProgressPayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( step_type=StepType.inference.value, step_id=step_id, model_response_text_delta=event.delta, @@ -448,9 +444,9 @@ class ChatAgent(ShieldRunnerMixin): tool_calls=tool_calls, ) - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepCompletePayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( step_type=StepType.inference.value, step_id=step_id, step_details=InferenceStep( @@ -498,17 +494,17 @@ class ChatAgent(ShieldRunnerMixin): return step_id = str(uuid.uuid4()) - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepStartPayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepStartPayload( step_type=StepType.tool_execution.value, step_id=step_id, ) ) ) - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepProgressPayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( step_type=StepType.tool_execution.value, step_id=step_id, tool_call=tool_call, @@ -525,9 +521,9 @@ class ChatAgent(ShieldRunnerMixin): ), "Currently not supporting multiple messages" result_message = result_messages[0] - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepCompletePayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( step_type=StepType.tool_execution.value, step_details=ToolExecutionStep( step_id=step_id, @@ -547,9 +543,9 @@ class ChatAgent(ShieldRunnerMixin): # TODO: add tool-input touchpoint and a "start" event for this step also # but that needs a lot more refactoring of Tool code potentially - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepCompletePayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( step_type=StepType.shield_call.value, step_details=ShieldCallStep( step_id=str(uuid.uuid4()), @@ -566,9 +562,9 @@ class ChatAgent(ShieldRunnerMixin): ) except SafetyException as e: - yield AgenticSystemTurnResponseStreamChunk( - event=AgenticSystemTurnResponseEvent( - payload=AgenticSystemTurnResponseStepCompletePayload( + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( step_type=StepType.shield_call.value, step_details=ShieldCallStep( step_id=str(uuid.uuid4()), @@ -616,18 +612,18 @@ class ChatAgent(ShieldRunnerMixin): enabled_tools = set(t.type for t in self.agent_config.tools) if attachments: if ( - AgenticSystemTool.code_interpreter.value in enabled_tools + AgentTool.code_interpreter.value in enabled_tools and self.agent_config.tool_choice == ToolChoice.required ): return False else: return True - return AgenticSystemTool.memory.value in enabled_tools + return AgentTool.memory.value in enabled_tools def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]: for t in self.agent_config.tools: - if t.type == AgenticSystemTool.memory.value: + if t.type == AgentTool.memory.value: return t return None diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_stack/providers/impls/meta_reference/agents/agents.py similarity index 71% rename from llama_toolchain/agentic_system/meta_reference/agentic_system.py rename to llama_stack/providers/impls/meta_reference/agents/agents.py index 3990ab58a..022c8c3d1 100644 --- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -10,20 +10,20 @@ import tempfile import uuid from typing import AsyncGenerator -from llama_toolchain.inference.api import Inference -from llama_toolchain.memory.api import Memory -from llama_toolchain.safety.api import Safety -from llama_toolchain.agentic_system.api import * # noqa: F403 -from llama_toolchain.tools.builtin import ( +from llama_stack.apis.inference import Inference +from llama_stack.apis.memory import Memory +from llama_stack.apis.safety import Safety +from llama_stack.apis.agents import * # noqa: F403 + +from .agent_instance import ChatAgent +from .config import MetaReferenceImplConfig +from .tools.builtin import ( CodeInterpreterTool, PhotogenTool, SearchTool, WolframAlphaTool, ) -from llama_toolchain.tools.safety import with_safety - -from .agent_instance import ChatAgent -from .config import MetaReferenceImplConfig +from .tools.safety import with_safety logger = logging.getLogger() @@ -33,7 +33,7 @@ logger.setLevel(logging.INFO) AGENT_INSTANCES_BY_ID = {} -class MetaReferenceAgenticSystemImpl(AgenticSystem): +class MetaReferenceAgentsImpl(Agents): def __init__( self, config: MetaReferenceImplConfig, @@ -49,28 +49,18 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem): async def initialize(self) -> None: pass - async def create_agentic_system( + async def create_agent( self, agent_config: AgentConfig, - ) -> AgenticSystemCreateResponse: + ) -> AgentCreateResponse: agent_id = str(uuid.uuid4()) builtin_tools = [] for tool_defn in agent_config.tools: if isinstance(tool_defn, WolframAlphaToolDefinition): - key = self.config.wolfram_api_key - if not key: - raise ValueError("Wolfram API key not defined in config") - tool = WolframAlphaTool(key) + tool = WolframAlphaTool(tool_defn.api_key) elif isinstance(tool_defn, SearchToolDefinition): - key = None - if tool_defn.engine == SearchEngineType.brave: - key = self.config.brave_search_api_key - elif tool_defn.engine == SearchEngineType.bing: - key = self.config.bing_search_api_key - if not key: - raise ValueError("API key not defined in config") - tool = SearchTool(tool_defn.engine, key) + tool = SearchTool(tool_defn.engine, tool_defn.api_key) elif isinstance(tool_defn, CodeInterpreterToolDefinition): tool = CodeInterpreterTool() elif isinstance(tool_defn, PhotogenToolDefinition): @@ -95,24 +85,24 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem): builtin_tools=builtin_tools, ) - return AgenticSystemCreateResponse( + return AgentCreateResponse( agent_id=agent_id, ) - async def create_agentic_system_session( + async def create_agent_session( self, agent_id: str, session_name: str, - ) -> AgenticSystemSessionCreateResponse: + ) -> AgentSessionCreateResponse: assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found" agent = AGENT_INSTANCES_BY_ID[agent_id] session = agent.create_session(session_name) - return AgenticSystemSessionCreateResponse( + return AgentSessionCreateResponse( session_id=session.session_id, ) - async def create_agentic_system_turn( + async def create_agent_turn( self, agent_id: str, session_id: str, @@ -126,7 +116,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem): stream: Optional[bool] = False, ) -> AsyncGenerator: # wrapper request to make it easier to pass around (internal only, not exposed to API) - request = AgenticSystemTurnCreateRequest( + request = AgentTurnCreateRequest( agent_id=agent_id, session_id=session_id, messages=messages, diff --git a/llama_stack/providers/impls/meta_reference/agents/config.py b/llama_stack/providers/impls/meta_reference/agents/config.py new file mode 100644 index 000000000..17beb348e --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/agents/config.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class MetaReferenceImplConfig(BaseModel): ... diff --git a/llama_toolchain/tools/custom/__init__.py b/llama_stack/providers/impls/meta_reference/agents/rag/__init__.py similarity index 100% rename from llama_toolchain/tools/custom/__init__.py rename to llama_stack/providers/impls/meta_reference/agents/rag/__init__.py diff --git a/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py b/llama_stack/providers/impls/meta_reference/agents/rag/context_retriever.py similarity index 95% rename from llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py rename to llama_stack/providers/impls/meta_reference/agents/rag/context_retriever.py index afcc6afd1..5ebb94a31 100644 --- a/llama_toolchain/agentic_system/meta_reference/rag/context_retriever.py +++ b/llama_stack/providers/impls/meta_reference/agents/rag/context_retriever.py @@ -10,14 +10,14 @@ from jinja2 import Template from llama_models.llama3.api import * # noqa: F403 -from llama_toolchain.agentic_system.api import ( +from llama_stack.apis.agents import ( DefaultMemoryQueryGeneratorConfig, LLMMemoryQueryGeneratorConfig, MemoryQueryGenerator, MemoryQueryGeneratorConfig, ) from termcolor import cprint # noqa: F401 -from llama_toolchain.inference.api import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 async def generate_rag_query( diff --git a/llama_toolchain/agentic_system/meta_reference/safety.py b/llama_stack/providers/impls/meta_reference/agents/safety.py similarity index 98% rename from llama_toolchain/agentic_system/meta_reference/safety.py rename to llama_stack/providers/impls/meta_reference/agents/safety.py index 4bbb1f2f1..f7148ddce 100644 --- a/llama_toolchain/agentic_system/meta_reference/safety.py +++ b/llama_stack/providers/impls/meta_reference/agents/safety.py @@ -7,15 +7,15 @@ from typing import List from llama_models.llama3.api.datatypes import Message, Role, UserMessage -from termcolor import cprint -from llama_toolchain.safety.api import ( +from llama_stack.apis.safety import ( OnViolationAction, RunShieldRequest, Safety, ShieldDefinition, ShieldResponse, ) +from termcolor import cprint class SafetyException(Exception): # noqa: N818 diff --git a/llama_toolchain/tools/ipython_tool/__init__.py b/llama_stack/providers/impls/meta_reference/agents/tests/__init__.py similarity index 100% rename from llama_toolchain/tools/ipython_tool/__init__.py rename to llama_stack/providers/impls/meta_reference/agents/tests/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/code_execution.py b/llama_stack/providers/impls/meta_reference/agents/tests/code_execution.py new file mode 100644 index 000000000..495cd2c92 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/agents/tests/code_execution.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import unittest + +from llama_models.llama3.api.datatypes import ( + Attachment, + BuiltinTool, + CompletionMessage, + StopReason, + ToolCall, +) + +from ..tools.builtin import CodeInterpreterTool + + +class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase): + async def test_matplotlib(self): + tool = CodeInterpreterTool() + code = """ +import matplotlib.pyplot as plt +import numpy as np + +x = np.array([1, 1]) +y = np.array([0, 10]) + +plt.plot(x, y) +plt.title('x = 1') +plt.xlabel('x') +plt.ylabel('y') +plt.grid(True) +plt.axvline(x=1, color='r') +plt.show() + """ + message = CompletionMessage( + role="assistant", + content="", + tool_calls=[ + ToolCall( + call_id="call_id", + tool_name=BuiltinTool.code_interpreter, + arguments={"code": code}, + ) + ], + stop_reason=StopReason.end_of_message, + ) + ret = await tool.run([message]) + + self.assertEqual(len(ret), 1) + + output = ret[0].content + self.assertIsInstance(output, Attachment) + self.assertEqual(output.mime_type, "image/png") + + async def test_path_unlink(self): + tool = CodeInterpreterTool() + code = """ +import os +from pathlib import Path +import tempfile + +dpath = Path(os.environ["MPLCONFIGDIR"]) +with open(dpath / "test", "w") as f: + f.write("hello") + +Path(dpath / "test").unlink() +print("_OK_") + """ + message = CompletionMessage( + role="assistant", + content="", + tool_calls=[ + ToolCall( + call_id="call_id", + tool_name=BuiltinTool.code_interpreter, + arguments={"code": code}, + ) + ], + stop_reason=StopReason.end_of_message, + ) + ret = await tool.run([message]) + + self.assertEqual(len(ret), 1) + + output = ret[0].content + self.assertTrue("_OK_" in output) + + +if __name__ == "__main__": + unittest.main() diff --git a/llama_toolchain/dataset/api/__init__.py b/llama_stack/providers/impls/meta_reference/agents/tools/__init__.py similarity index 83% rename from llama_toolchain/dataset/api/__init__.py rename to llama_stack/providers/impls/meta_reference/agents/tools/__init__.py index a7e55ba91..756f351d8 100644 --- a/llama_toolchain/dataset/api/__init__.py +++ b/llama_stack/providers/impls/meta_reference/agents/tools/__init__.py @@ -3,5 +3,3 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/tools/base.py b/llama_stack/providers/impls/meta_reference/agents/tools/base.py similarity index 90% rename from llama_toolchain/tools/base.py rename to llama_stack/providers/impls/meta_reference/agents/tools/base.py index 324cce0e2..15fba7e2e 100644 --- a/llama_toolchain/tools/base.py +++ b/llama_stack/providers/impls/meta_reference/agents/tools/base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from typing import List -from llama_toolchain.inference.api import Message +from llama_stack.apis.inference import Message class BaseTool(ABC): diff --git a/llama_toolchain/tools/builtin.py b/llama_stack/providers/impls/meta_reference/agents/tools/builtin.py similarity index 99% rename from llama_toolchain/tools/builtin.py rename to llama_stack/providers/impls/meta_reference/agents/tools/builtin.py index 56fda3723..4c9cdfcd2 100644 --- a/llama_toolchain/tools/builtin.py +++ b/llama_stack/providers/impls/meta_reference/agents/tools/builtin.py @@ -21,8 +21,8 @@ from .ipython_tool.code_execution import ( TOOLS_ATTACHMENT_KEY_REGEX, ) -from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.agentic_system.api import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.agents import * # noqa: F403 from .base import BaseTool diff --git a/llama_toolchain/agentic_system/api/__init__.py b/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/__init__.py similarity index 83% rename from llama_toolchain/agentic_system/api/__init__.py rename to llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/__init__.py index a7e55ba91..756f351d8 100644 --- a/llama_toolchain/agentic_system/api/__init__.py +++ b/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/__init__.py @@ -3,5 +3,3 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/tools/ipython_tool/code_env_prefix.py b/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_env_prefix.py similarity index 100% rename from llama_toolchain/tools/ipython_tool/code_env_prefix.py rename to llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_env_prefix.py diff --git a/llama_toolchain/tools/ipython_tool/code_execution.py b/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_execution.py similarity index 100% rename from llama_toolchain/tools/ipython_tool/code_execution.py rename to llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_execution.py diff --git a/llama_toolchain/tools/ipython_tool/matplotlib_custom_backend.py b/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py similarity index 100% rename from llama_toolchain/tools/ipython_tool/matplotlib_custom_backend.py rename to llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py diff --git a/llama_toolchain/tools/ipython_tool/utils.py b/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/utils.py similarity index 100% rename from llama_toolchain/tools/ipython_tool/utils.py rename to llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/utils.py diff --git a/llama_toolchain/tools/safety.py b/llama_stack/providers/impls/meta_reference/agents/tools/safety.py similarity index 88% rename from llama_toolchain/tools/safety.py rename to llama_stack/providers/impls/meta_reference/agents/tools/safety.py index 24051af8a..d36dc3490 100644 --- a/llama_toolchain/tools/safety.py +++ b/llama_stack/providers/impls/meta_reference/agents/tools/safety.py @@ -6,10 +6,10 @@ from typing import List -from llama_toolchain.agentic_system.meta_reference.safety import ShieldRunnerMixin +from llama_stack.apis.inference import Message +from llama_stack.apis.safety import Safety, ShieldDefinition -from llama_toolchain.inference.api import Message -from llama_toolchain.safety.api import Safety, ShieldDefinition +from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin from .builtin import BaseTool diff --git a/llama_toolchain/inference/meta_reference/__init__.py b/llama_stack/providers/impls/meta_reference/inference/__init__.py similarity index 100% rename from llama_toolchain/inference/meta_reference/__init__.py rename to llama_stack/providers/impls/meta_reference/inference/__init__.py diff --git a/llama_toolchain/inference/meta_reference/config.py b/llama_stack/providers/impls/meta_reference/inference/config.py similarity index 96% rename from llama_toolchain/inference/meta_reference/config.py rename to llama_stack/providers/impls/meta_reference/inference/config.py index a0bbc5820..27943cb2c 100644 --- a/llama_toolchain/inference/meta_reference/config.py +++ b/llama_stack/providers/impls/meta_reference/inference/config.py @@ -11,9 +11,9 @@ from llama_models.datatypes import ModelFamily from llama_models.schema_utils import json_schema_type from llama_models.sku_list import all_registered_models, resolve_model -from pydantic import BaseModel, Field, field_validator +from llama_stack.apis.inference import QuantizationConfig -from llama_toolchain.inference.api import QuantizationConfig +from pydantic import BaseModel, Field, field_validator @json_schema_type diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py similarity index 98% rename from llama_toolchain/inference/meta_reference/generation.py rename to llama_stack/providers/impls/meta_reference/inference/generation.py index d13b9570d..e1643b21a 100644 --- a/llama_toolchain/inference/meta_reference/generation.py +++ b/llama_stack/providers/impls/meta_reference/inference/generation.py @@ -28,9 +28,9 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.reference_impl.model import Transformer from llama_models.sku_list import resolve_model +from llama_stack.apis.inference import QuantizationType -from llama_toolchain.common.model_utils import model_local_dir -from llama_toolchain.inference.api import QuantizationType +from llama_stack.distribution.utils.model_utils import model_local_dir from termcolor import cprint from .config import MetaReferenceImplConfig diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py similarity index 96% rename from llama_toolchain/inference/meta_reference/inference.py rename to llama_stack/providers/impls/meta_reference/inference/inference.py index 247c08f23..597a4cb55 100644 --- a/llama_toolchain/inference/meta_reference/inference.py +++ b/llama_stack/providers/impls/meta_reference/inference/inference.py @@ -11,7 +11,7 @@ from typing import AsyncIterator, Union from llama_models.llama3.api.datatypes import StopReason from llama_models.sku_list import resolve_model -from llama_toolchain.inference.api import ( +from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseEvent, @@ -21,13 +21,13 @@ from llama_toolchain.inference.api import ( ToolCallDelta, ToolCallParseStatus, ) -from llama_toolchain.inference.prepare_messages import prepare_messages +from llama_stack.providers.utils.inference.prepare_messages import prepare_messages from .config import MetaReferenceImplConfig from .model_parallel import LlamaModelParallelGenerator from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.inference.api import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 # there's a single model parallel process running serving the model. for now, # we don't support multiple concurrent requests to this process. @@ -57,7 +57,7 @@ class MetaReferenceInferenceImpl(Inference): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), - tools: Optional[List[ToolDefinition]] = list(), + tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, stream: Optional[bool] = False, @@ -70,7 +70,7 @@ class MetaReferenceInferenceImpl(Inference): model=model, messages=messages, sampling_params=sampling_params, - tools=tools, + tools=tools or [], tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, stream=stream, diff --git a/llama_toolchain/inference/meta_reference/model_parallel.py b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py similarity index 100% rename from llama_toolchain/inference/meta_reference/model_parallel.py rename to llama_stack/providers/impls/meta_reference/inference/model_parallel.py diff --git a/llama_toolchain/inference/meta_reference/parallel_utils.py b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py similarity index 100% rename from llama_toolchain/inference/meta_reference/parallel_utils.py rename to llama_stack/providers/impls/meta_reference/inference/parallel_utils.py diff --git a/llama_toolchain/evaluations/api/__init__.py b/llama_stack/providers/impls/meta_reference/inference/quantization/__init__.py similarity index 83% rename from llama_toolchain/evaluations/api/__init__.py rename to llama_stack/providers/impls/meta_reference/inference/quantization/__init__.py index a7e55ba91..756f351d8 100644 --- a/llama_toolchain/evaluations/api/__init__.py +++ b/llama_stack/providers/impls/meta_reference/inference/quantization/__init__.py @@ -3,5 +3,3 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/inference/quantization/fp8_impls.py b/llama_stack/providers/impls/meta_reference/inference/quantization/fp8_impls.py similarity index 100% rename from llama_toolchain/inference/quantization/fp8_impls.py rename to llama_stack/providers/impls/meta_reference/inference/quantization/fp8_impls.py diff --git a/llama_toolchain/inference/quantization/loader.py b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py similarity index 97% rename from llama_toolchain/inference/quantization/loader.py rename to llama_stack/providers/impls/meta_reference/inference/quantization/loader.py index 54827dce9..9d28c9853 100644 --- a/llama_toolchain/inference/quantization/loader.py +++ b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py @@ -14,9 +14,9 @@ import torch from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region from llama_models.llama3.api.model import Transformer, TransformerBlock -from llama_toolchain.inference.api import QuantizationType +from llama_stack.apis.inference import QuantizationType -from llama_toolchain.inference.api.config import ( +from llama_stack.apis.inference.config import ( CheckpointQuantizationFormat, MetaReferenceImplConfig, ) diff --git a/llama_toolchain/batch_inference/api/__init__.py b/llama_stack/providers/impls/meta_reference/inference/quantization/scripts/__init__.py similarity index 83% rename from llama_toolchain/batch_inference/api/__init__.py rename to llama_stack/providers/impls/meta_reference/inference/quantization/scripts/__init__.py index a7e55ba91..756f351d8 100644 --- a/llama_toolchain/batch_inference/api/__init__.py +++ b/llama_stack/providers/impls/meta_reference/inference/quantization/scripts/__init__.py @@ -3,5 +3,3 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/inference/quantization/scripts/build_conda.sh b/llama_stack/providers/impls/meta_reference/inference/quantization/scripts/build_conda.sh similarity index 100% rename from llama_toolchain/inference/quantization/scripts/build_conda.sh rename to llama_stack/providers/impls/meta_reference/inference/quantization/scripts/build_conda.sh diff --git a/llama_toolchain/inference/quantization/scripts/quantize_checkpoint.py b/llama_stack/providers/impls/meta_reference/inference/quantization/scripts/quantize_checkpoint.py similarity index 100% rename from llama_toolchain/inference/quantization/scripts/quantize_checkpoint.py rename to llama_stack/providers/impls/meta_reference/inference/quantization/scripts/quantize_checkpoint.py diff --git a/llama_toolchain/inference/quantization/scripts/run_quantize_checkpoint.sh b/llama_stack/providers/impls/meta_reference/inference/quantization/scripts/run_quantize_checkpoint.sh similarity index 100% rename from llama_toolchain/inference/quantization/scripts/run_quantize_checkpoint.sh rename to llama_stack/providers/impls/meta_reference/inference/quantization/scripts/run_quantize_checkpoint.sh diff --git a/llama_toolchain/inference/quantization/test_fp8.py b/llama_stack/providers/impls/meta_reference/inference/quantization/test_fp8.py similarity index 100% rename from llama_toolchain/inference/quantization/test_fp8.py rename to llama_stack/providers/impls/meta_reference/inference/quantization/test_fp8.py diff --git a/llama_toolchain/memory/meta_reference/faiss/__init__.py b/llama_stack/providers/impls/meta_reference/memory/__init__.py similarity index 100% rename from llama_toolchain/memory/meta_reference/faiss/__init__.py rename to llama_stack/providers/impls/meta_reference/memory/__init__.py diff --git a/llama_toolchain/memory/meta_reference/faiss/config.py b/llama_stack/providers/impls/meta_reference/memory/config.py similarity index 100% rename from llama_toolchain/memory/meta_reference/faiss/config.py rename to llama_stack/providers/impls/meta_reference/memory/config.py diff --git a/llama_toolchain/memory/meta_reference/faiss/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py similarity index 95% rename from llama_toolchain/memory/meta_reference/faiss/faiss.py rename to llama_stack/providers/impls/meta_reference/memory/faiss.py index 2dcff4d25..ee716430e 100644 --- a/llama_toolchain/memory/meta_reference/faiss/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -15,13 +15,14 @@ from numpy.typing import NDArray from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.memory.api import * # noqa: F403 -from llama_toolchain.memory.common.vector_store import ( +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.providers.utils.memory.vector_store import ( ALL_MINILM_L6_V2_DIMENSION, BankWithIndex, EmbeddingIndex, ) -from llama_toolchain.telemetry import tracing +from llama_stack.providers.utils.telemetry import tracing + from .config import FaissImplConfig logger = logging.getLogger(__name__) diff --git a/llama_toolchain/safety/meta_reference/__init__.py b/llama_stack/providers/impls/meta_reference/safety/__init__.py similarity index 100% rename from llama_toolchain/safety/meta_reference/__init__.py rename to llama_stack/providers/impls/meta_reference/safety/__init__.py diff --git a/llama_toolchain/safety/meta_reference/config.py b/llama_stack/providers/impls/meta_reference/safety/config.py similarity index 100% rename from llama_toolchain/safety/meta_reference/config.py rename to llama_stack/providers/impls/meta_reference/safety/config.py diff --git a/llama_toolchain/safety/meta_reference/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py similarity index 96% rename from llama_toolchain/safety/meta_reference/safety.py rename to llama_stack/providers/impls/meta_reference/safety/safety.py index 6c75e74e8..baf0ebb46 100644 --- a/llama_toolchain/safety/meta_reference/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -8,8 +8,8 @@ import asyncio from llama_models.sku_list import resolve_model -from llama_toolchain.common.model_utils import model_local_dir -from llama_toolchain.safety.api import * # noqa +from llama_stack.distribution.utils.model_utils import model_local_dir +from llama_stack.apis.safety import * # noqa from .config import SafetyConfig from .shields import ( diff --git a/llama_toolchain/safety/meta_reference/shields/__init__.py b/llama_stack/providers/impls/meta_reference/safety/shields/__init__.py similarity index 100% rename from llama_toolchain/safety/meta_reference/shields/__init__.py rename to llama_stack/providers/impls/meta_reference/safety/shields/__init__.py diff --git a/llama_toolchain/safety/meta_reference/shields/base.py b/llama_stack/providers/impls/meta_reference/safety/shields/base.py similarity index 97% rename from llama_toolchain/safety/meta_reference/shields/base.py rename to llama_stack/providers/impls/meta_reference/safety/shields/base.py index ed939212d..64e64e2fd 100644 --- a/llama_toolchain/safety/meta_reference/shields/base.py +++ b/llama_stack/providers/impls/meta_reference/safety/shields/base.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from typing import List from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message -from llama_toolchain.safety.api import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" diff --git a/llama_toolchain/safety/meta_reference/shields/code_scanner.py b/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py similarity index 95% rename from llama_toolchain/safety/meta_reference/shields/code_scanner.py rename to llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py index 564d15a53..75ec7c37b 100644 --- a/llama_toolchain/safety/meta_reference/shields/code_scanner.py +++ b/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py @@ -8,7 +8,7 @@ from codeshield.cs import CodeShield from termcolor import cprint from .base import ShieldResponse, TextShield -from llama_toolchain.safety.api import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 class CodeScannerShield(TextShield): diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/contrib/__init__.py b/llama_stack/providers/impls/meta_reference/safety/shields/contrib/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/safety/shields/contrib/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py b/llama_stack/providers/impls/meta_reference/safety/shields/contrib/third_party_shield.py similarity index 93% rename from llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py rename to llama_stack/providers/impls/meta_reference/safety/shields/contrib/third_party_shield.py index 61a5977ed..9aa8adea8 100644 --- a/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py +++ b/llama_stack/providers/impls/meta_reference/safety/shields/contrib/third_party_shield.py @@ -8,7 +8,7 @@ from typing import List from llama_models.llama3.api.datatypes import Message -from llama_toolchain.safety.meta_reference.shields.base import ( +from llama_stack.safety.meta_reference.shields.base import ( OnViolationAction, ShieldBase, ShieldResponse, diff --git a/llama_toolchain/safety/meta_reference/shields/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py similarity index 99% rename from llama_toolchain/safety/meta_reference/shields/llama_guard.py rename to llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py index fe04baa00..c5c4f58a6 100644 --- a/llama_toolchain/safety/meta_reference/shields/llama_guard.py +++ b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py @@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role from transformers import AutoModelForCausalLM, AutoTokenizer from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse -from llama_toolchain.safety.api import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 SAFE_RESPONSE = "safe" _INSTANCE = None diff --git a/llama_toolchain/safety/meta_reference/shields/prompt_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py similarity index 99% rename from llama_toolchain/safety/meta_reference/shields/prompt_guard.py rename to llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py index a1097a6f7..67bc6a6db 100644 --- a/llama_toolchain/safety/meta_reference/shields/prompt_guard.py +++ b/llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py @@ -14,7 +14,7 @@ from termcolor import cprint from transformers import AutoModelForSequenceClassification, AutoTokenizer from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield -from llama_toolchain.safety.api import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 class PromptGuardShield(TextShield): diff --git a/llama_toolchain/telemetry/console/__init__.py b/llama_stack/providers/impls/meta_reference/telemetry/__init__.py similarity index 100% rename from llama_toolchain/telemetry/console/__init__.py rename to llama_stack/providers/impls/meta_reference/telemetry/__init__.py diff --git a/llama_toolchain/telemetry/console/config.py b/llama_stack/providers/impls/meta_reference/telemetry/config.py similarity index 100% rename from llama_toolchain/telemetry/console/config.py rename to llama_stack/providers/impls/meta_reference/telemetry/config.py diff --git a/llama_toolchain/telemetry/console/console.py b/llama_stack/providers/impls/meta_reference/telemetry/console.py similarity index 97% rename from llama_toolchain/telemetry/console/console.py rename to llama_stack/providers/impls/meta_reference/telemetry/console.py index 2e7b9980d..b56c704a6 100644 --- a/llama_toolchain/telemetry/console/console.py +++ b/llama_stack/providers/impls/meta_reference/telemetry/console.py @@ -6,7 +6,7 @@ from typing import Optional -from llama_toolchain.telemetry.api import * # noqa: F403 +from llama_stack.apis.telemetry import * # noqa: F403 from .config import ConsoleConfig diff --git a/llama_stack/providers/impls/sqlite/__init__.py b/llama_stack/providers/impls/sqlite/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/impls/sqlite/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/registry/__init__.py b/llama_stack/providers/registry/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/registry/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/agentic_system/providers.py b/llama_stack/providers/registry/agents.py similarity index 69% rename from llama_toolchain/agentic_system/providers.py rename to llama_stack/providers/registry/agents.py index 79e66d15e..3195c92da 100644 --- a/llama_toolchain/agentic_system/providers.py +++ b/llama_stack/providers/registry/agents.py @@ -6,16 +6,15 @@ from typing import List -from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec +from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec def available_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( - api=Api.agentic_system, + api=Api.agents, provider_id="meta-reference", pip_packages=[ - "codeshield", "matplotlib", "pillow", "pandas", @@ -23,8 +22,8 @@ def available_providers() -> List[ProviderSpec]: "torch", "transformers", ], - module="llama_toolchain.agentic_system.meta_reference", - config_class="llama_toolchain.agentic_system.meta_reference.MetaReferenceImplConfig", + module="llama_stack.providers.impls.meta_reference.agents", + config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceImplConfig", api_dependencies=[ Api.inference, Api.safety, diff --git a/llama_toolchain/inference/providers.py b/llama_stack/providers/registry/inference.py similarity index 64% rename from llama_toolchain/inference/providers.py rename to llama_stack/providers/registry/inference.py index 928c6ef57..2fa8c98dc 100644 --- a/llama_toolchain/inference/providers.py +++ b/llama_stack/providers/registry/inference.py @@ -6,7 +6,7 @@ from typing import List -from llama_toolchain.core.datatypes import * # noqa: F403 +from llama_stack.distribution.datatypes import * # noqa: F403 def available_providers() -> List[ProviderSpec]: @@ -17,22 +17,21 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[ "accelerate", "blobfile", - "codeshield", "fairscale", "fbgemm-gpu==0.8.0", "torch", "transformers", "zmq", ], - module="llama_toolchain.inference.meta_reference", - config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig", + module="llama_stack.providers.impls.meta_reference.inference", + config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceImplConfig", ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( adapter_id="ollama", pip_packages=["ollama"], - module="llama_toolchain.inference.adapters.ollama", + module="llama_stack.providers.adapters.inference.ollama", ), ), remote_provider_spec( @@ -40,8 +39,8 @@ def available_providers() -> List[ProviderSpec]: adapter=AdapterSpec( adapter_id="tgi", pip_packages=["huggingface_hub"], - module="llama_toolchain.inference.adapters.tgi", - config_class="llama_toolchain.inference.adapters.tgi.TGIImplConfig", + module="llama_stack.providers.adapters.inference.tgi", + config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig", ), ), remote_provider_spec( @@ -51,8 +50,8 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[ "fireworks-ai", ], - module="llama_toolchain.inference.adapters.fireworks", - config_class="llama_toolchain.inference.adapters.fireworks.FireworksImplConfig", + module="llama_stack.providers.adapters.inference.fireworks", + config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig", ), ), remote_provider_spec( @@ -62,8 +61,8 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[ "together", ], - module="llama_toolchain.inference.adapters.together", - config_class="llama_toolchain.inference.adapters.together.TogetherImplConfig", + module="llama_stack.providers.adapters.inference.together", + config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig", ), ), ] diff --git a/llama_toolchain/memory/providers.py b/llama_stack/providers/registry/memory.py similarity index 58% rename from llama_toolchain/memory/providers.py rename to llama_stack/providers/registry/memory.py index d3336278a..12487567a 100644 --- a/llama_toolchain/memory/providers.py +++ b/llama_stack/providers/registry/memory.py @@ -6,7 +6,7 @@ from typing import List -from llama_toolchain.core.datatypes import * # noqa: F403 +from llama_stack.distribution.datatypes import * # noqa: F403 EMBEDDING_DEPS = [ "blobfile", @@ -20,26 +20,26 @@ def available_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( api=Api.memory, - provider_id="meta-reference-faiss", + provider_id="meta-reference", pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], - module="llama_toolchain.memory.meta_reference.faiss", - config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig", + module="llama_stack.providers.impls.meta_reference.memory", + config_class="llama_stack.providers.impls.meta_reference.memory.FaissImplConfig", ), remote_provider_spec( - api=Api.memory, - adapter=AdapterSpec( + Api.memory, + AdapterSpec( adapter_id="chromadb", pip_packages=EMBEDDING_DEPS + ["chromadb-client"], - module="llama_toolchain.memory.adapters.chroma", + module="llama_stack.providers.adapters.memory.chroma", ), ), remote_provider_spec( - api=Api.memory, - adapter=AdapterSpec( + Api.memory, + AdapterSpec( adapter_id="pgvector", pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"], - module="llama_toolchain.memory.adapters.pgvector", - config_class="llama_toolchain.memory.adapters.pgvector.PGVectorConfig", + module="llama_stack.providers.adapters.memory.pgvector", + config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig", ), ), ] diff --git a/llama_toolchain/safety/providers.py b/llama_stack/providers/registry/safety.py similarity index 69% rename from llama_toolchain/safety/providers.py rename to llama_stack/providers/registry/safety.py index c523e628e..6e9583066 100644 --- a/llama_toolchain/safety/providers.py +++ b/llama_stack/providers/registry/safety.py @@ -6,7 +6,7 @@ from typing import List -from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec +from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec def available_providers() -> List[ProviderSpec]: @@ -20,7 +20,7 @@ def available_providers() -> List[ProviderSpec]: "torch", "transformers", ], - module="llama_toolchain.safety.meta_reference", - config_class="llama_toolchain.safety.meta_reference.SafetyConfig", + module="llama_stack.providers.impls.meta_reference.safety", + config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig", ), ] diff --git a/llama_toolchain/telemetry/providers.py b/llama_stack/providers/registry/telemetry.py similarity index 58% rename from llama_toolchain/telemetry/providers.py rename to llama_stack/providers/registry/telemetry.py index 00038e569..29c57fd86 100644 --- a/llama_toolchain/telemetry/providers.py +++ b/llama_stack/providers/registry/telemetry.py @@ -6,16 +6,16 @@ from typing import List -from llama_toolchain.core.datatypes import * # noqa: F403 +from llama_stack.distribution.datatypes import * # noqa: F403 def available_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( api=Api.telemetry, - provider_id="console", + provider_id="meta-reference", pip_packages=[], - module="llama_toolchain.telemetry.console", - config_class="llama_toolchain.telemetry.console.ConsoleConfig", + module="llama_stack.providers.impls.meta_reference.telemetry", + config_class="llama_stack.providers.impls.meta_reference.telemetry.ConsoleConfig", ), ] diff --git a/llama_stack/providers/routers/__init__.py b/llama_stack/providers/routers/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/routers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/routers/memory/__init__.py b/llama_stack/providers/routers/memory/__init__.py new file mode 100644 index 000000000..d4dbbb1d4 --- /dev/null +++ b/llama_stack/providers/routers/memory/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, List, Tuple + +from llama_stack.distribution.datatypes import Api + + +async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]): + from .memory import MemoryRouterImpl + + impl = MemoryRouterImpl(inner_impls, deps) + await impl.initialize() + return impl diff --git a/llama_stack/providers/routers/memory/memory.py b/llama_stack/providers/routers/memory/memory.py new file mode 100644 index 000000000..b96cde626 --- /dev/null +++ b/llama_stack/providers/routers/memory/memory.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List, Tuple + +from llama_stack.distribution.datatypes import Api +from llama_stack.apis.memory import * # noqa: F403 + + +class MemoryRouterImpl(Memory): + """Routes to an provider based on the memory bank type""" + + def __init__( + self, + inner_impls: List[Tuple[str, Any]], + deps: List[Api], + ) -> None: + self.deps = deps + + bank_types = [v.value for v in MemoryBankType] + + self.providers = {} + for routing_key, provider_impl in inner_impls: + if routing_key not in bank_types: + raise ValueError( + f"Unknown routing key `{routing_key}` for memory bank type" + ) + self.providers[routing_key] = provider_impl + + self.bank_id_to_type = {} + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + for p in self.providers.values(): + await p.shutdown() + + def get_provider(self, bank_type): + if bank_type not in self.providers: + raise ValueError(f"Memory bank type {bank_type} not supported") + + return self.providers[bank_type] + + async def create_memory_bank( + self, + name: str, + config: MemoryBankConfig, + url: Optional[URL] = None, + ) -> MemoryBank: + provider = self.get_provider(config.type) + bank = await provider.create_memory_bank(name, config, url) + self.bank_id_to_type[bank.bank_id] = config.type + return bank + + async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: + bank_type = self.bank_id_to_type.get(bank_id) + if not bank_type: + raise ValueError(f"Could not find bank type for {bank_id}") + + provider = self.get_provider(bank_type) + return await provider.get_memory_bank(bank_id) + + async def insert_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + bank_type = self.bank_id_to_type.get(bank_id) + if not bank_type: + raise ValueError(f"Could not find bank type for {bank_id}") + + provider = self.get_provider(bank_type) + return await provider.insert_documents(bank_id, documents, ttl_seconds) + + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + bank_type = self.bank_id_to_type.get(bank_id) + if not bank_type: + raise ValueError(f"Could not find bank type for {bank_id}") + + provider = self.get_provider(bank_type) + return await provider.query_documents(bank_id, query, params) diff --git a/llama_stack/providers/utils/__init__.py b/llama_stack/providers/utils/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/utils/inference/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/inference/prepare_messages.py b/llama_stack/providers/utils/inference/prepare_messages.py similarity index 97% rename from llama_toolchain/inference/prepare_messages.py rename to llama_stack/providers/utils/inference/prepare_messages.py index 92e94f8d2..0519cbfab 100644 --- a/llama_toolchain/inference/prepare_messages.py +++ b/llama_stack/providers/utils/inference/prepare_messages.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.inference.api import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 from llama_models.llama3.prompt_templates import ( BuiltinToolGenerator, FunctionTagCustomToolGenerator, diff --git a/llama_stack/providers/utils/memory/__init__.py b/llama_stack/providers/utils/memory/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/utils/memory/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/memory/common/file_utils.py b/llama_stack/providers/utils/memory/file_utils.py similarity index 100% rename from llama_toolchain/memory/common/file_utils.py rename to llama_stack/providers/utils/memory/file_utils.py diff --git a/llama_toolchain/memory/common/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py similarity index 98% rename from llama_toolchain/memory/common/vector_store.py rename to llama_stack/providers/utils/memory/vector_store.py index baa3fbf21..d575a985b 100644 --- a/llama_toolchain/memory/common/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -20,7 +20,7 @@ from pypdf import PdfReader from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.tokenizer import Tokenizer -from llama_toolchain.memory.api import * # noqa: F403 +from llama_stack.apis.memory import * # noqa: F403 ALL_MINILM_L6_V2_DIMENSION = 384 diff --git a/llama_stack/providers/utils/telemetry/__init__.py b/llama_stack/providers/utils/telemetry/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/utils/telemetry/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py similarity index 99% rename from llama_toolchain/telemetry/tracing.py rename to llama_stack/providers/utils/telemetry/tracing.py index 6afe5c2fb..5284dfac0 100644 --- a/llama_toolchain/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -15,7 +15,7 @@ from functools import wraps from typing import Any, Dict, List -from llama_toolchain.telemetry.api import * # noqa: F403 +from llama_stack.apis.telemetry import * # noqa: F403 def generate_short_uuid(len: int = 12): diff --git a/llama_toolchain/agentic_system/__init__.py b/llama_toolchain/agentic_system/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/llama_toolchain/agentic_system/execute_with_custom_tools.py b/llama_toolchain/agentic_system/execute_with_custom_tools.py deleted file mode 100644 index e8038bc20..000000000 --- a/llama_toolchain/agentic_system/execute_with_custom_tools.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import AsyncGenerator, List - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.agentic_system.api import * # noqa: F403 -from llama_toolchain.memory.api import * # noqa: F403 -from llama_toolchain.safety.api import * # noqa: F403 - -from llama_toolchain.agentic_system.api import ( - AgenticSystemTurnResponseEventType as EventType, -) -from llama_toolchain.tools.custom.datatypes import CustomTool - - -class AgentWithCustomToolExecutor: - def __init__( - self, - api: AgenticSystem, - agent_id: str, - session_id: str, - agent_config: AgentConfig, - custom_tools: List[CustomTool], - ): - self.api = api - self.agent_id = agent_id - self.session_id = session_id - self.agent_config = agent_config - self.custom_tools = custom_tools - - async def execute_turn( - self, - messages: List[Message], - attachments: Optional[List[Attachment]] = None, - max_iters: int = 5, - stream: bool = True, - ) -> AsyncGenerator: - tools_dict = {t.get_name(): t for t in self.custom_tools} - - current_messages = messages.copy() - n_iter = 0 - while n_iter < max_iters: - n_iter += 1 - - request = AgenticSystemTurnCreateRequest( - agent_id=self.agent_id, - session_id=self.session_id, - messages=current_messages, - attachments=attachments, - stream=stream, - ) - - turn = None - async for chunk in self.api.create_agentic_system_turn(request): - if chunk.event.payload.event_type != EventType.turn_complete.value: - yield chunk - else: - turn = chunk.event.payload.turn - - message = turn.output_message - if len(message.tool_calls) == 0: - yield chunk - return - - if message.stop_reason == StopReason.out_of_tokens: - yield chunk - return - - tool_call = message.tool_calls[0] - if tool_call.tool_name not in tools_dict: - m = ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else", - ) - next_message = m - else: - tool = tools_dict[tool_call.tool_name] - result_messages = await execute_custom_tool(tool, message) - next_message = result_messages[0] - - yield next_message - current_messages = [next_message] - - -async def execute_custom_tool(tool: CustomTool, message: Message) -> List[Message]: - result_messages = await tool.run([message]) - assert ( - len(result_messages) == 1 - ), f"Expected single message, got {len(result_messages)}" - - return result_messages diff --git a/llama_toolchain/agentic_system/meta_reference/config.py b/llama_toolchain/agentic_system/meta_reference/config.py deleted file mode 100644 index f1a92f2e7..000000000 --- a/llama_toolchain/agentic_system/meta_reference/config.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Optional - -from pydantic import BaseModel - - -class MetaReferenceImplConfig(BaseModel): - brave_search_api_key: Optional[str] = None - bing_search_api_key: Optional[str] = None - wolfram_api_key: Optional[str] = None diff --git a/llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml b/llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml deleted file mode 100644 index 2a25cb9dd..000000000 --- a/llama_toolchain/configs/distributions/conda/local-conda-example-build.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: local-conda-example -distribution_spec: - description: Use code from `llama_toolchain` itself to serve all llama stack APIs - providers: - inference: meta-reference - memory: meta-reference-faiss - safety: meta-reference - agentic_system: meta-reference - telemetry: console -image_type: conda diff --git a/llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml b/llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml deleted file mode 100644 index 0bdb18802..000000000 --- a/llama_toolchain/configs/distributions/docker/local-docker-example-build.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: local-docker-example -distribution_spec: - description: Use code from `llama_toolchain` itself to serve all llama stack APIs - providers: - inference: meta-reference - memory: meta-reference-faiss - safety: meta-reference - agentic_system: meta-reference - telemetry: console -image_type: docker diff --git a/llama_toolchain/core/configure.py b/llama_toolchain/core/configure.py deleted file mode 100644 index 7f9aa0140..000000000 --- a/llama_toolchain/core/configure.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict - -from llama_toolchain.core.datatypes import * # noqa: F403 -from termcolor import cprint - -from llama_toolchain.common.prompt_for_config import prompt_for_config -from llama_toolchain.core.distribution import api_providers -from llama_toolchain.core.dynamic import instantiate_class_type - - -def configure_api_providers(existing_configs: Dict[str, Any]) -> None: - all_providers = api_providers() - - provider_configs = {} - for api_str, stub_config in existing_configs.items(): - api = Api(api_str) - providers = all_providers[api] - provider_id = stub_config["provider_id"] - if provider_id not in providers: - raise ValueError( - f"Unknown provider `{provider_id}` is not available for API `{api_str}`" - ) - - provider_spec = providers[provider_id] - cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"]) - config_type = instantiate_class_type(provider_spec.config_class) - - try: - existing_provider_config = config_type(**stub_config) - except Exception: - existing_provider_config = None - - provider_config = prompt_for_config( - config_type, - existing_provider_config, - ) - print("") - - provider_configs[api_str] = { - "provider_id": provider_id, - **provider_config.dict(), - } - - return provider_configs diff --git a/llama_toolchain/core/dynamic.py b/llama_toolchain/core/dynamic.py deleted file mode 100644 index adb9b5dac..000000000 --- a/llama_toolchain/core/dynamic.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import asyncio -import importlib -from typing import Any, Dict - -from .datatypes import ProviderSpec, RemoteProviderSpec - - -def instantiate_class_type(fully_qualified_name): - module_name, class_name = fully_qualified_name.rsplit(".", 1) - module = importlib.import_module(module_name) - return getattr(module, class_name) - - -# returns a class implementing the protocol corresponding to the Api -def instantiate_provider( - provider_spec: ProviderSpec, - provider_config: Dict[str, Any], - deps: Dict[str, ProviderSpec], -): - module = importlib.import_module(provider_spec.module) - - config_type = instantiate_class_type(provider_spec.config_class) - if isinstance(provider_spec, RemoteProviderSpec): - if provider_spec.adapter: - method = "get_adapter_impl" - else: - method = "get_client_impl" - else: - method = "get_provider_impl" - - config = config_type(**provider_config) - fn = getattr(module, method) - impl = asyncio.run(fn(config, deps)) - impl.__provider_spec__ = provider_spec - impl.__provider_config__ = config - return impl diff --git a/llama_toolchain/inference/api/__init__.py b/llama_toolchain/inference/api/__init__.py deleted file mode 100644 index a7e55ba91..000000000 --- a/llama_toolchain/inference/api/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/memory/api/__init__.py b/llama_toolchain/memory/api/__init__.py deleted file mode 100644 index a7e55ba91..000000000 --- a/llama_toolchain/memory/api/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/post_training/api/__init__.py b/llama_toolchain/post_training/api/__init__.py deleted file mode 100644 index a7e55ba91..000000000 --- a/llama_toolchain/post_training/api/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/reward_scoring/api/__init__.py b/llama_toolchain/reward_scoring/api/__init__.py deleted file mode 100644 index a7e55ba91..000000000 --- a/llama_toolchain/reward_scoring/api/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/safety/api/__init__.py b/llama_toolchain/safety/api/__init__.py deleted file mode 100644 index a7e55ba91..000000000 --- a/llama_toolchain/safety/api/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/stack.py b/llama_toolchain/stack.py deleted file mode 100644 index 1e2976ab3..000000000 --- a/llama_toolchain/stack.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.agentic_system.api import * # noqa: F403 -from llama_toolchain.dataset.api import * # noqa: F403 -from llama_toolchain.evaluations.api import * # noqa: F403 -from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.batch_inference.api import * # noqa: F403 -from llama_toolchain.memory.api import * # noqa: F403 -from llama_toolchain.telemetry.api import * # noqa: F403 -from llama_toolchain.post_training.api import * # noqa: F403 -from llama_toolchain.reward_scoring.api import * # noqa: F403 -from llama_toolchain.synthetic_data_generation.api import * # noqa: F403 -from llama_toolchain.safety.api import * # noqa: F403 - - -class LlamaStack( - Inference, - BatchInference, - AgenticSystem, - RewardScoring, - Safety, - SyntheticDataGeneration, - Datasets, - Telemetry, - PostTraining, - Memory, - Evaluations, -): - pass diff --git a/llama_toolchain/synthetic_data_generation/api/__init__.py b/llama_toolchain/synthetic_data_generation/api/__init__.py deleted file mode 100644 index a7e55ba91..000000000 --- a/llama_toolchain/synthetic_data_generation/api/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/telemetry/api/__init__.py b/llama_toolchain/telemetry/api/__init__.py deleted file mode 100644 index a7e55ba91..000000000 --- a/llama_toolchain/telemetry/api/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .api import * # noqa: F401 F403 diff --git a/llama_toolchain/tools/custom/datatypes.py b/llama_toolchain/tools/custom/datatypes.py deleted file mode 100644 index 05b142d6f..000000000 --- a/llama_toolchain/tools/custom/datatypes.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json - -from abc import abstractmethod -from typing import Dict, List - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.agentic_system.api import * # noqa: F403 - - -class CustomTool: - """ - Developers can define their custom tools that models can use - by extending this class. - - Developers need to provide - - name - - description - - params_definition - - implement tool's behavior in `run_impl` method - - NOTE: The return of the `run` method needs to be json serializable - """ - - @abstractmethod - def get_name(self) -> str: - raise NotImplementedError - - @abstractmethod - def get_description(self) -> str: - raise NotImplementedError - - @abstractmethod - def get_params_definition(self) -> Dict[str, ToolParamDefinition]: - raise NotImplementedError - - def get_instruction_string(self) -> str: - return f"Use the function '{self.get_name()}' to: {self.get_description()}" - - def parameters_for_system_prompt(self) -> str: - return json.dumps( - { - "name": self.get_name(), - "description": self.get_description(), - "parameters": { - name: definition.__dict__ - for name, definition in self.get_params_definition().items() - }, - } - ) - - def get_tool_definition(self) -> FunctionCallToolDefinition: - return FunctionCallToolDefinition( - function_name=self.get_name(), - description=self.get_description(), - parameters=self.get_params_definition(), - ) - - @abstractmethod - async def run(self, messages: List[Message]) -> List[Message]: - raise NotImplementedError - - -class SingleMessageCustomTool(CustomTool): - """ - Helper class to handle custom tools that take a single message - Extending this class and implementing the `run_impl` method will - allow for the tool be called by the model and the necessary plumbing. - """ - - async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: - assert len(messages) == 1, "Expected single message" - - message = messages[0] - - tool_call = message.tool_calls[0] - - try: - response = await self.run_impl(**tool_call.arguments) - response_str = json.dumps(response, ensure_ascii=False) - except Exception as e: - response_str = f"Error when running tool: {e}" - - message = ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=response_str, - ) - return [message] - - @abstractmethod - async def run_impl(self, *args, **kwargs): - raise NotImplementedError() diff --git a/requirements.txt b/requirements.txt index 6dc053c63..c84ccb870 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ fire httpx huggingface-hub llama-models>=0.0.17 +python-dotenv pydantic requests termcolor diff --git a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html index 6e7fe287f..d3f6f593b 100644 --- a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html +++ b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-11 16:05:23.016090" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-17 12:55:45.538053" }, "servers": [ { @@ -209,7 +209,7 @@ } } }, - "/agentic_system/create": { + "/agents/create": { "post": { "responses": { "200": { @@ -217,21 +217,21 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/AgenticSystemCreateResponse" + "$ref": "#/components/schemas/AgentCreateResponse" } } } } }, "tags": [ - "AgenticSystem" + "Agents" ], "parameters": [], "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateAgenticSystemRequest" + "$ref": "#/components/schemas/CreateAgentRequest" } } }, @@ -239,7 +239,7 @@ } } }, - "/agentic_system/session/create": { + "/agents/session/create": { "post": { "responses": { "200": { @@ -247,21 +247,21 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/AgenticSystemSessionCreateResponse" + "$ref": "#/components/schemas/AgentSessionCreateResponse" } } } } }, "tags": [ - "AgenticSystem" + "Agents" ], "parameters": [], "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateAgenticSystemSessionRequest" + "$ref": "#/components/schemas/CreateAgentSessionRequest" } } }, @@ -269,29 +269,29 @@ } } }, - "/agentic_system/turn/create": { + "/agents/turn/create": { "post": { "responses": { "200": { "description": "OK", "content": { - "text/event-stream": { + "application/json": { "schema": { - "$ref": "#/components/schemas/AgenticSystemTurnResponseStreamChunk" + "$ref": "#/components/schemas/AgentTurnResponseStreamChunk" } } } } }, "tags": [ - "AgenticSystem" + "Agents" ], "parameters": [], "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateAgenticSystemTurnRequest" + "$ref": "#/components/schemas/CreateAgentTurnRequest" } } }, @@ -352,7 +352,7 @@ } } }, - "/agentic_system/delete": { + "/agents/delete": { "post": { "responses": { "200": { @@ -360,14 +360,14 @@ } }, "tags": [ - "AgenticSystem" + "Agents" ], "parameters": [], "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/DeleteAgenticSystemRequest" + "$ref": "#/components/schemas/DeleteAgentsRequest" } } }, @@ -375,7 +375,7 @@ } } }, - "/agentic_system/session/delete": { + "/agents/session/delete": { "post": { "responses": { "200": { @@ -383,14 +383,14 @@ } }, "tags": [ - "AgenticSystem" + "Agents" ], "parameters": [], "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/DeleteAgenticSystemSessionRequest" + "$ref": "#/components/schemas/DeleteAgentsSessionRequest" } } }, @@ -594,7 +594,7 @@ } } }, - "/agentic_system/session/get": { + "/agents/session/get": { "post": { "responses": { "200": { @@ -609,7 +609,7 @@ } }, "tags": [ - "AgenticSystem" + "Agents" ], "parameters": [ { @@ -633,7 +633,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/GetAgenticSystemSessionRequest" + "$ref": "#/components/schemas/GetAgentsSessionRequest" } } }, @@ -641,7 +641,7 @@ } } }, - "/agentic_system/step/get": { + "/agents/step/get": { "get": { "responses": { "200": { @@ -649,14 +649,14 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/AgenticSystemStepResponse" + "$ref": "#/components/schemas/AgentStepResponse" } } } } }, "tags": [ - "AgenticSystem" + "Agents" ], "parameters": [ { @@ -686,7 +686,7 @@ ] } }, - "/agentic_system/turn/get": { + "/agents/turn/get": { "get": { "responses": { "200": { @@ -701,7 +701,7 @@ } }, "tags": [ - "AgenticSystem" + "Agents" ], "parameters": [ { @@ -2672,7 +2672,7 @@ "type" ] }, - "CreateAgenticSystemRequest": { + "CreateAgentRequest": { "type": "object", "properties": { "agent_config": { @@ -2684,7 +2684,7 @@ "agent_config" ] }, - "AgenticSystemCreateResponse": { + "AgentCreateResponse": { "type": "object", "properties": { "agent_id": { @@ -2696,7 +2696,7 @@ "agent_id" ] }, - "CreateAgenticSystemSessionRequest": { + "CreateAgentSessionRequest": { "type": "object", "properties": { "agent_id": { @@ -2712,7 +2712,7 @@ "session_name" ] }, - "AgenticSystemSessionCreateResponse": { + "AgentSessionCreateResponse": { "type": "object", "properties": { "session_id": { @@ -2753,7 +2753,7 @@ "mime_type" ] }, - "CreateAgenticSystemTurnRequest": { + "CreateAgentTurnRequest": { "type": "object", "properties": { "agent_id": { @@ -2792,25 +2792,25 @@ "messages" ] }, - "AgenticSystemTurnResponseEvent": { + "AgentTurnResponseEvent": { "type": "object", "properties": { "payload": { "oneOf": [ { - "$ref": "#/components/schemas/AgenticSystemTurnResponseStepStartPayload" + "$ref": "#/components/schemas/AgentTurnResponseStepStartPayload" }, { - "$ref": "#/components/schemas/AgenticSystemTurnResponseStepProgressPayload" + "$ref": "#/components/schemas/AgentTurnResponseStepProgressPayload" }, { - "$ref": "#/components/schemas/AgenticSystemTurnResponseStepCompletePayload" + "$ref": "#/components/schemas/AgentTurnResponseStepCompletePayload" }, { - "$ref": "#/components/schemas/AgenticSystemTurnResponseTurnStartPayload" + "$ref": "#/components/schemas/AgentTurnResponseTurnStartPayload" }, { - "$ref": "#/components/schemas/AgenticSystemTurnResponseTurnCompletePayload" + "$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload" } ] } @@ -2821,7 +2821,7 @@ ], "title": "Streamed agent execution response." }, - "AgenticSystemTurnResponseStepCompletePayload": { + "AgentTurnResponseStepCompletePayload": { "type": "object", "properties": { "event_type": { @@ -2861,7 +2861,7 @@ "step_details" ] }, - "AgenticSystemTurnResponseStepProgressPayload": { + "AgentTurnResponseStepProgressPayload": { "type": "object", "properties": { "event_type": { @@ -2897,7 +2897,7 @@ "step_id" ] }, - "AgenticSystemTurnResponseStepStartPayload": { + "AgentTurnResponseStepStartPayload": { "type": "object", "properties": { "event_type": { @@ -2949,11 +2949,11 @@ "step_id" ] }, - "AgenticSystemTurnResponseStreamChunk": { + "AgentTurnResponseStreamChunk": { "type": "object", "properties": { "event": { - "$ref": "#/components/schemas/AgenticSystemTurnResponseEvent" + "$ref": "#/components/schemas/AgentTurnResponseEvent" } }, "additionalProperties": false, @@ -2961,7 +2961,7 @@ "event" ] }, - "AgenticSystemTurnResponseTurnCompletePayload": { + "AgentTurnResponseTurnCompletePayload": { "type": "object", "properties": { "event_type": { @@ -2978,7 +2978,7 @@ "turn" ] }, - "AgenticSystemTurnResponseTurnStartPayload": { + "AgentTurnResponseTurnStartPayload": { "type": "object", "properties": { "event_type": { @@ -3532,7 +3532,7 @@ "config" ] }, - "DeleteAgenticSystemRequest": { + "DeleteAgentsRequest": { "type": "object", "properties": { "agent_id": { @@ -3544,7 +3544,7 @@ "agent_id" ] }, - "DeleteAgenticSystemSessionRequest": { + "DeleteAgentsSessionRequest": { "type": "object", "properties": { "agent_id": { @@ -3720,7 +3720,7 @@ "metrics" ] }, - "GetAgenticSystemSessionRequest": { + "GetAgentsSessionRequest": { "type": "object", "properties": { "turn_ids": { @@ -3764,7 +3764,7 @@ ], "title": "A single session of an interaction with an Agentic System." }, - "AgenticSystemStepResponse": { + "AgentStepResponse": { "type": "object", "properties": { "step": { @@ -3859,7 +3859,6 @@ "required": [ "document_id", "content", - "mime_type", "metadata" ] }, @@ -5142,37 +5141,37 @@ ], "tags": [ { - "name": "SyntheticDataGeneration" - }, - { - "name": "Datasets" - }, - { - "name": "Evaluations" + "name": "Agents" }, { "name": "Safety" }, { - "name": "Inference" + "name": "SyntheticDataGeneration" }, { "name": "Telemetry" }, { - "name": "PostTraining" - }, - { - "name": "Memory" + "name": "Datasets" }, { "name": "RewardScoring" }, + { + "name": "Evaluations" + }, + { + "name": "PostTraining" + }, + { + "name": "Inference" + }, { "name": "BatchInference" }, { - "name": "AgenticSystem" + "name": "Memory" }, { "name": "BuiltinTool", @@ -5343,56 +5342,56 @@ "description": "" }, { - "name": "CreateAgenticSystemRequest", - "description": "" + "name": "CreateAgentRequest", + "description": "" }, { - "name": "AgenticSystemCreateResponse", - "description": "" + "name": "AgentCreateResponse", + "description": "" }, { - "name": "CreateAgenticSystemSessionRequest", - "description": "" + "name": "CreateAgentSessionRequest", + "description": "" }, { - "name": "AgenticSystemSessionCreateResponse", - "description": "" + "name": "AgentSessionCreateResponse", + "description": "" }, { "name": "Attachment", "description": "" }, { - "name": "CreateAgenticSystemTurnRequest", - "description": "" + "name": "CreateAgentTurnRequest", + "description": "" }, { - "name": "AgenticSystemTurnResponseEvent", - "description": "Streamed agent execution response.\n\n" + "name": "AgentTurnResponseEvent", + "description": "Streamed agent execution response.\n\n" }, { - "name": "AgenticSystemTurnResponseStepCompletePayload", - "description": "" + "name": "AgentTurnResponseStepCompletePayload", + "description": "" }, { - "name": "AgenticSystemTurnResponseStepProgressPayload", - "description": "" + "name": "AgentTurnResponseStepProgressPayload", + "description": "" }, { - "name": "AgenticSystemTurnResponseStepStartPayload", - "description": "" + "name": "AgentTurnResponseStepStartPayload", + "description": "" }, { - "name": "AgenticSystemTurnResponseStreamChunk", - "description": "" + "name": "AgentTurnResponseStreamChunk", + "description": "" }, { - "name": "AgenticSystemTurnResponseTurnCompletePayload", - "description": "" + "name": "AgentTurnResponseTurnCompletePayload", + "description": "" }, { - "name": "AgenticSystemTurnResponseTurnStartPayload", - "description": "" + "name": "AgentTurnResponseTurnStartPayload", + "description": "" }, { "name": "InferenceStep", @@ -5443,12 +5442,12 @@ "description": "" }, { - "name": "DeleteAgenticSystemRequest", - "description": "" + "name": "DeleteAgentsRequest", + "description": "" }, { - "name": "DeleteAgenticSystemSessionRequest", - "description": "" + "name": "DeleteAgentsSessionRequest", + "description": "" }, { "name": "DeleteDatasetRequest", @@ -5487,16 +5486,16 @@ "description": "" }, { - "name": "GetAgenticSystemSessionRequest", - "description": "" + "name": "GetAgentsSessionRequest", + "description": "" }, { "name": "Session", "description": "A single session of an interaction with an Agentic System.\n\n" }, { - "name": "AgenticSystemStepResponse", - "description": "" + "name": "AgentStepResponse", + "description": "" }, { "name": "GetDocumentsRequest", @@ -5675,7 +5674,7 @@ { "name": "Operations", "tags": [ - "AgenticSystem", + "Agents", "BatchInference", "Datasets", "Evaluations", @@ -5692,16 +5691,16 @@ "name": "Types", "tags": [ "AgentConfig", - "AgenticSystemCreateResponse", - "AgenticSystemSessionCreateResponse", - "AgenticSystemStepResponse", - "AgenticSystemTurnResponseEvent", - "AgenticSystemTurnResponseStepCompletePayload", - "AgenticSystemTurnResponseStepProgressPayload", - "AgenticSystemTurnResponseStepStartPayload", - "AgenticSystemTurnResponseStreamChunk", - "AgenticSystemTurnResponseTurnCompletePayload", - "AgenticSystemTurnResponseTurnStartPayload", + "AgentCreateResponse", + "AgentSessionCreateResponse", + "AgentStepResponse", + "AgentTurnResponseEvent", + "AgentTurnResponseStepCompletePayload", + "AgentTurnResponseStepProgressPayload", + "AgentTurnResponseStepStartPayload", + "AgentTurnResponseStreamChunk", + "AgentTurnResponseTurnCompletePayload", + "AgentTurnResponseTurnStartPayload", "Attachment", "BatchChatCompletionRequest", "BatchChatCompletionResponse", @@ -5722,14 +5721,14 @@ "CompletionRequest", "CompletionResponse", "CompletionResponseStreamChunk", - "CreateAgenticSystemRequest", - "CreateAgenticSystemSessionRequest", - "CreateAgenticSystemTurnRequest", + "CreateAgentRequest", + "CreateAgentSessionRequest", + "CreateAgentTurnRequest", "CreateDatasetRequest", "CreateMemoryBankRequest", "DPOAlignmentConfig", - "DeleteAgenticSystemRequest", - "DeleteAgenticSystemSessionRequest", + "DeleteAgentsRequest", + "DeleteAgentsSessionRequest", "DeleteDatasetRequest", "DeleteDocumentsRequest", "DialogGenerations", @@ -5746,7 +5745,7 @@ "EvaluationJobStatusResponse", "FinetuningAlgorithm", "FunctionCallToolDefinition", - "GetAgenticSystemSessionRequest", + "GetAgentsSessionRequest", "GetDocumentsRequest", "InferenceStep", "InsertDocumentsRequest", diff --git a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml index 4d1b27bb7..e96142b00 100644 --- a/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml +++ b/rfcs/RFC-0001-llama-stack-assets/llama-stack-spec.yaml @@ -152,7 +152,7 @@ components: - model - instructions type: object - AgenticSystemCreateResponse: + AgentCreateResponse: additionalProperties: false properties: agent_id: @@ -160,7 +160,7 @@ components: required: - agent_id type: object - AgenticSystemSessionCreateResponse: + AgentSessionCreateResponse: additionalProperties: false properties: session_id: @@ -168,7 +168,7 @@ components: required: - session_id type: object - AgenticSystemStepResponse: + AgentStepResponse: additionalProperties: false properties: step: @@ -180,21 +180,21 @@ components: required: - step type: object - AgenticSystemTurnResponseEvent: + AgentTurnResponseEvent: additionalProperties: false properties: payload: oneOf: - - $ref: '#/components/schemas/AgenticSystemTurnResponseStepStartPayload' - - $ref: '#/components/schemas/AgenticSystemTurnResponseStepProgressPayload' - - $ref: '#/components/schemas/AgenticSystemTurnResponseStepCompletePayload' - - $ref: '#/components/schemas/AgenticSystemTurnResponseTurnStartPayload' - - $ref: '#/components/schemas/AgenticSystemTurnResponseTurnCompletePayload' + - $ref: '#/components/schemas/AgentTurnResponseStepStartPayload' + - $ref: '#/components/schemas/AgentTurnResponseStepProgressPayload' + - $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload' + - $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload' + - $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload' required: - payload title: Streamed agent execution response. type: object - AgenticSystemTurnResponseStepCompletePayload: + AgentTurnResponseStepCompletePayload: additionalProperties: false properties: event_type: @@ -218,7 +218,7 @@ components: - step_type - step_details type: object - AgenticSystemTurnResponseStepProgressPayload: + AgentTurnResponseStepProgressPayload: additionalProperties: false properties: event_type: @@ -244,7 +244,7 @@ components: - step_type - step_id type: object - AgenticSystemTurnResponseStepStartPayload: + AgentTurnResponseStepStartPayload: additionalProperties: false properties: event_type: @@ -274,15 +274,15 @@ components: - step_type - step_id type: object - AgenticSystemTurnResponseStreamChunk: + AgentTurnResponseStreamChunk: additionalProperties: false properties: event: - $ref: '#/components/schemas/AgenticSystemTurnResponseEvent' + $ref: '#/components/schemas/AgentTurnResponseEvent' required: - event type: object - AgenticSystemTurnResponseTurnCompletePayload: + AgentTurnResponseTurnCompletePayload: additionalProperties: false properties: event_type: @@ -294,7 +294,7 @@ components: - event_type - turn type: object - AgenticSystemTurnResponseTurnStartPayload: + AgentTurnResponseTurnStartPayload: additionalProperties: false properties: event_type: @@ -617,7 +617,7 @@ components: - delta title: streamed completion response. type: object - CreateAgenticSystemRequest: + CreateAgentRequest: additionalProperties: false properties: agent_config: @@ -625,7 +625,7 @@ components: required: - agent_config type: object - CreateAgenticSystemSessionRequest: + CreateAgentSessionRequest: additionalProperties: false properties: agent_id: @@ -636,7 +636,7 @@ components: - agent_id - session_name type: object - CreateAgenticSystemTurnRequest: + CreateAgentTurnRequest: additionalProperties: false properties: agent_id: @@ -741,7 +741,7 @@ components: - epsilon - gamma type: object - DeleteAgenticSystemRequest: + DeleteAgentsRequest: additionalProperties: false properties: agent_id: @@ -749,7 +749,7 @@ components: required: - agent_id type: object - DeleteAgenticSystemSessionRequest: + DeleteAgentsSessionRequest: additionalProperties: false properties: agent_id: @@ -973,7 +973,7 @@ components: - description - parameters type: object - GetAgenticSystemSessionRequest: + GetAgentsSessionRequest: additionalProperties: false properties: turn_ids: @@ -1155,7 +1155,6 @@ components: required: - document_id - content - - mime_type - metadata type: object MemoryRetrievalStep: @@ -2357,77 +2356,77 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-09-11 16:05:23.016090" + \ draft and subject to change.\n Generated at 2024-09-17 12:55:45.538053" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema openapi: 3.1.0 paths: - /agentic_system/create: + /agents/create: post: parameters: [] requestBody: content: application/json: schema: - $ref: '#/components/schemas/CreateAgenticSystemRequest' + $ref: '#/components/schemas/CreateAgentRequest' required: true responses: '200': content: application/json: schema: - $ref: '#/components/schemas/AgenticSystemCreateResponse' + $ref: '#/components/schemas/AgentCreateResponse' description: OK tags: - - AgenticSystem - /agentic_system/delete: + - Agents + /agents/delete: post: parameters: [] requestBody: content: application/json: schema: - $ref: '#/components/schemas/DeleteAgenticSystemRequest' + $ref: '#/components/schemas/DeleteAgentsRequest' required: true responses: '200': description: OK tags: - - AgenticSystem - /agentic_system/session/create: + - Agents + /agents/session/create: post: parameters: [] requestBody: content: application/json: schema: - $ref: '#/components/schemas/CreateAgenticSystemSessionRequest' + $ref: '#/components/schemas/CreateAgentSessionRequest' required: true responses: '200': content: application/json: schema: - $ref: '#/components/schemas/AgenticSystemSessionCreateResponse' + $ref: '#/components/schemas/AgentSessionCreateResponse' description: OK tags: - - AgenticSystem - /agentic_system/session/delete: + - Agents + /agents/session/delete: post: parameters: [] requestBody: content: application/json: schema: - $ref: '#/components/schemas/DeleteAgenticSystemSessionRequest' + $ref: '#/components/schemas/DeleteAgentsSessionRequest' required: true responses: '200': description: OK tags: - - AgenticSystem - /agentic_system/session/get: + - Agents + /agents/session/get: post: parameters: - in: query @@ -2444,7 +2443,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/GetAgenticSystemSessionRequest' + $ref: '#/components/schemas/GetAgentsSessionRequest' required: true responses: '200': @@ -2454,8 +2453,8 @@ paths: $ref: '#/components/schemas/Session' description: OK tags: - - AgenticSystem - /agentic_system/step/get: + - Agents + /agents/step/get: get: parameters: - in: query @@ -2478,29 +2477,29 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/AgenticSystemStepResponse' + $ref: '#/components/schemas/AgentStepResponse' description: OK tags: - - AgenticSystem - /agentic_system/turn/create: + - Agents + /agents/turn/create: post: parameters: [] requestBody: content: application/json: schema: - $ref: '#/components/schemas/CreateAgenticSystemTurnRequest' + $ref: '#/components/schemas/CreateAgentTurnRequest' required: true responses: '200': content: - text/event-stream: + application/json: schema: - $ref: '#/components/schemas/AgenticSystemTurnResponseStreamChunk' + $ref: '#/components/schemas/AgentTurnResponseStreamChunk' description: OK tags: - - AgenticSystem - /agentic_system/turn/get: + - Agents + /agents/turn/get: get: parameters: - in: query @@ -2521,7 +2520,7 @@ paths: $ref: '#/components/schemas/Turn' description: OK tags: - - AgenticSystem + - Agents /batch_inference/chat_completion: post: parameters: [] @@ -3145,17 +3144,17 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: SyntheticDataGeneration -- name: Datasets -- name: Evaluations +- name: Agents - name: Safety -- name: Inference +- name: SyntheticDataGeneration - name: Telemetry -- name: PostTraining -- name: Memory +- name: Datasets - name: RewardScoring +- name: Evaluations +- name: PostTraining +- name: Inference - name: BatchInference -- name: AgenticSystem +- name: Memory - description: name: BuiltinTool - description: name: WolframAlphaToolDefinition -- description: - name: CreateAgenticSystemRequest -- description: - name: AgenticSystemCreateResponse -- description: - name: CreateAgenticSystemSessionRequest -- description: - name: AgenticSystemSessionCreateResponse + name: AgentSessionCreateResponse - description: name: Attachment -- description: - name: CreateAgenticSystemTurnRequest + name: CreateAgentTurnRequest - description: 'Streamed agent execution response. - ' - name: AgenticSystemTurnResponseEvent -- description: ' + name: AgentTurnResponseEvent +- description: - name: AgenticSystemTurnResponseStepCompletePayload -- description: - name: AgenticSystemTurnResponseStepProgressPayload -- description: - name: AgenticSystemTurnResponseStepStartPayload -- description: - name: AgenticSystemTurnResponseStreamChunk -- description: - name: AgenticSystemTurnResponseTurnCompletePayload -- description: - name: AgenticSystemTurnResponseTurnStartPayload + name: AgentTurnResponseTurnStartPayload - description: name: InferenceStep - description: name: MemoryBank -- description: - name: DeleteAgenticSystemRequest -- description: - name: DeleteAgenticSystemSessionRequest + name: DeleteAgentsSessionRequest - description: name: DeleteDatasetRequest @@ -3397,17 +3395,17 @@ tags: - description: name: EvaluateTextGenerationRequest -- description: - name: GetAgenticSystemSessionRequest + name: GetAgentsSessionRequest - description: 'A single session of an interaction with an Agentic System. ' name: Session -- description: - name: AgenticSystemStepResponse + name: AgentStepResponse - description: name: GetDocumentsRequest @@ -3552,7 +3550,7 @@ tags: x-tagGroups: - name: Operations tags: - - AgenticSystem + - Agents - BatchInference - Datasets - Evaluations @@ -3566,16 +3564,16 @@ x-tagGroups: - name: Types tags: - AgentConfig - - AgenticSystemCreateResponse - - AgenticSystemSessionCreateResponse - - AgenticSystemStepResponse - - AgenticSystemTurnResponseEvent - - AgenticSystemTurnResponseStepCompletePayload - - AgenticSystemTurnResponseStepProgressPayload - - AgenticSystemTurnResponseStepStartPayload - - AgenticSystemTurnResponseStreamChunk - - AgenticSystemTurnResponseTurnCompletePayload - - AgenticSystemTurnResponseTurnStartPayload + - AgentCreateResponse + - AgentSessionCreateResponse + - AgentStepResponse + - AgentTurnResponseEvent + - AgentTurnResponseStepCompletePayload + - AgentTurnResponseStepProgressPayload + - AgentTurnResponseStepStartPayload + - AgentTurnResponseStreamChunk + - AgentTurnResponseTurnCompletePayload + - AgentTurnResponseTurnStartPayload - Attachment - BatchChatCompletionRequest - BatchChatCompletionResponse @@ -3596,14 +3594,14 @@ x-tagGroups: - CompletionRequest - CompletionResponse - CompletionResponseStreamChunk - - CreateAgenticSystemRequest - - CreateAgenticSystemSessionRequest - - CreateAgenticSystemTurnRequest + - CreateAgentRequest + - CreateAgentSessionRequest + - CreateAgentTurnRequest - CreateDatasetRequest - CreateMemoryBankRequest - DPOAlignmentConfig - - DeleteAgenticSystemRequest - - DeleteAgenticSystemSessionRequest + - DeleteAgentsRequest + - DeleteAgentsSessionRequest - DeleteDatasetRequest - DeleteDocumentsRequest - DialogGenerations @@ -3620,7 +3618,7 @@ x-tagGroups: - EvaluationJobStatusResponse - FinetuningAlgorithm - FunctionCallToolDefinition - - GetAgenticSystemSessionRequest + - GetAgentsSessionRequest - GetDocumentsRequest - InferenceStep - InsertDocumentsRequest diff --git a/rfcs/RFC-0001-llama-stack.md b/rfcs/RFC-0001-llama-stack.md index 805e8cd84..137b15d11 100644 --- a/rfcs/RFC-0001-llama-stack.md +++ b/rfcs/RFC-0001-llama-stack.md @@ -1,19 +1,19 @@ # The Llama Stack API **Authors:** -* Meta: @raghotham, @ashwinb, @hjshah, @jspisak +* Meta: @raghotham, @ashwinb, @hjshah, @jspisak ## Summary As part of the Llama 3.1 release, Meta is releasing an RFC for ‘Llama Stack’, a comprehensive set of interfaces / API for ML developers building on top of Llama foundation models. We are looking for feedback on where the API can be improved, any corner cases we may have missed and your general thoughts on how useful this will be. Ultimately, our hope is to create a standard for working with Llama models in order to simplify the developer experience and foster innovation across the Llama ecosystem. ## Motivation -Llama models were always intended to work as part of an overall system that can orchestrate several components, including calling external tools. Our vision is to go beyond the foundation models and give developers access to a broader system that gives them the flexibility to design and create custom offerings that align with their vision. This thinking started last year when we first introduced a system-level safety model. Meta has continued to release new components for orchestration at the system level and, most recently in Llama 3.1, we’ve introduced the Llama Guard 3 safety model that is multilingual, a prompt injection filter, Prompt Guard and refreshed v3 of our CyberSec Evals. We are also releasing a reference implementation of an agentic system to demonstrate how all the pieces fit together. +Llama models were always intended to work as part of an overall system that can orchestrate several components, including calling external tools. Our vision is to go beyond the foundation models and give developers access to a broader system that gives them the flexibility to design and create custom offerings that align with their vision. This thinking started last year when we first introduced a system-level safety model. Meta has continued to release new components for orchestration at the system level and, most recently in Llama 3.1, we’ve introduced the Llama Guard 3 safety model that is multilingual, a prompt injection filter, Prompt Guard and refreshed v3 of our CyberSec Evals. We are also releasing a reference implementation of an agentic system to demonstrate how all the pieces fit together. -While building the reference implementation, we realized that having a clean and consistent way to interface between components could be valuable not only for us but for anyone leveraging Llama models and other components as part of their system. We’ve also heard from the community as they face a similar challenge as components exist with overlapping functionality and there are incompatible interfaces and yet don't cover the end-to-end model life cycle. +While building the reference implementation, we realized that having a clean and consistent way to interface between components could be valuable not only for us but for anyone leveraging Llama models and other components as part of their system. We’ve also heard from the community as they face a similar challenge as components exist with overlapping functionality and there are incompatible interfaces and yet don't cover the end-to-end model life cycle. With these motivations, we engaged folks in industry, startups, and the broader developer community to help better define the interfaces of these components. We’re releasing this Llama Stack RFC as a set of standardized and opinionated interfaces for how to surface canonical toolchain components (like inference, fine-tuning, evals, synthetic data generation) and agentic applications to ML developers. Our hope is to have these become well adopted across the ecosystem, which should help with easier interoperability. We would like for builders of multiple components to provide implementations to these standard APIs so that there can be vertically integrated “distributions” of the Llama Stack that can work out of the box easily. -We welcome feedback and ways to improve the proposal. We’re excited to grow the ecosystem around Llama and lower barriers for both developers and platform providers. +We welcome feedback and ways to improve the proposal. We’re excited to grow the ecosystem around Llama and lower barriers for both developers and platform providers. ## Design decisions Meta releases weights of both the pretrained and instruction fine-tuned Llama models to support several use cases. These weights can be improved - fine tuned and aligned - with curated datasets to then be deployed for inference to support specific applications. The curated datasets can be produced manually by humans or synthetically by other models or by leveraging human feedback by collecting usage data of the application itself. This results in a continuous improvement cycle where the model gets better over time. This is the model life cycle. @@ -42,8 +42,8 @@ Note that as of today, in the OSS world, such a “loop” is often coded explic **Let's consider an example:** 1. The user asks the system "Who played the NBA finals last year?" -1. The model "understands" that this question needs to be answered using web search. It answers this abstractly with a message of the form "Please call the search tool for me with the query: 'List finalist teams for NBA in the last year' ". Note that the model by itself does not call the tool (of course!) -1. The executor consults the set of tool implementations which have been configured by the developer to find an implementation for the "search tool". If it does not find it, it returns an error to the model. Otherwise, it executes this tool and returns the result of this tool back to the model. +1. The model "understands" that this question needs to be answered using web search. It answers this abstractly with a message of the form "Please call the search tool for me with the query: 'List finalist teams for NBA in the last year' ". Note that the model by itself does not call the tool (of course!) +1. The executor consults the set of tool implementations which have been configured by the developer to find an implementation for the "search tool". If it does not find it, it returns an error to the model. Otherwise, it executes this tool and returns the result of this tool back to the model. 1. The model reasons once again (using all the messages above) and decides to send a final response "In 2023, Denver Nuggets played against the Miami Heat in the NBA finals." to the executor 1. The executor returns the response directly to the user (since there is no tool call to be executed.) @@ -65,7 +65,7 @@ We define the Llama Stack as a layer cake shown below. -The API is defined in the [YAML](RFC-0001-llama-stack-assets/llama-stack-spec.yaml) and [HTML](RFC-0001-llama-stack-assets/llama-stack-spec.html) files. These files were generated using the Pydantic definitions in (api/datatypes.py and api/endpoints.py) files that are in the llama-models, llama-toolchain, and llama-agentic-system repositories. +The API is defined in the [YAML](RFC-0001-llama-stack-assets/llama-stack-spec.yaml) and [HTML](RFC-0001-llama-stack-assets/llama-stack-spec.html) files. These files were generated using the Pydantic definitions in (api/datatypes.py and api/endpoints.py) files that are in the llama-models, llama-stack, and llama-agentic-system repositories. @@ -73,14 +73,14 @@ The API is defined in the [YAML](RFC-0001-llama-stack-assets/llama-stack-spec.ya ## Sample implementations -To prove out the API, we implemented a handful of use cases to make things more concrete. The [llama-agentic-system](https://github.com/meta-llama/llama-agentic-system) repository contains [6 different examples](https://github.com/meta-llama/llama-agentic-system/tree/main/examples/scripts) ranging from very basic to a multi turn agent. +To prove out the API, we implemented a handful of use cases to make things more concrete. The [llama-agentic-system](https://github.com/meta-llama/llama-agentic-system) repository contains [6 different examples](https://github.com/meta-llama/llama-agentic-system/tree/main/examples/scripts) ranging from very basic to a multi turn agent. -There is also a sample inference endpoint implementation in the [llama-toolchain](https://github.com/meta-llama/llama-toolchain/blob/main/llama_toolchain/inference/server.py) repository. +There is also a sample inference endpoint implementation in the [llama-stack](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/inference/server.py) repository. ## Limitations The reference implementation for Llama Stack APIs to date only includes sample implementations using the inference API. We are planning to flesh out the design of Llama Stack Distributions (distros) by combining capabilities from different providers into a single vertically integrated stack. We plan to implement other APIs and, of course, we’d love contributions!! -Thank you in advance for your feedback, support and contributions to make this a better API. +Thank you in advance for your feedback, support and contributions to make this a better API. Cheers! diff --git a/rfcs/openapi_generator/README.md b/rfcs/openapi_generator/README.md index 023486534..9d407905d 100644 --- a/rfcs/openapi_generator/README.md +++ b/rfcs/openapi_generator/README.md @@ -1,4 +1,4 @@ -The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_toolchain/[]/api/endpoints.py` using the `generate.py` utility. +The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/[]/api/endpoints.py` using the `generate.py` utility. Please install the following packages before running the script: diff --git a/rfcs/openapi_generator/generate.py b/rfcs/openapi_generator/generate.py index 279389a47..0eda7282b 100644 --- a/rfcs/openapi_generator/generate.py +++ b/rfcs/openapi_generator/generate.py @@ -31,7 +31,7 @@ from .pyopenapi.utility import Specification schema_utils.json_schema_type = json_schema_type -from llama_toolchain.stack import LlamaStack +from llama_stack.apis.stack import LlamaStack # TODO: this should be fixed in the generator itself so it reads appropriate annotations diff --git a/setup.py b/setup.py index d0cacb22d..430faa5a1 100644 --- a/setup.py +++ b/setup.py @@ -15,20 +15,20 @@ def read_requirements(): setup( - name="llama_toolchain", + name="llama_stack", version="0.0.17", author="Meta Llama", author_email="llama-oss@meta.com", - description="Llama toolchain", + description="Llama Stack", entry_points={ "console_scripts": [ - "llama = llama_toolchain.cli.llama:main", - "install-wheel-from-presigned = llama_toolchain.cli.scripts.run:install_wheel_from_presigned", + "llama = llama_stack.cli.llama:main", + "install-wheel-from-presigned = llama_stack.cli.scripts.run:install_wheel_from_presigned", ] }, long_description=open("README.md").read(), long_description_content_type="text/markdown", - url="https://github.com/meta-llama/llama-toolchain", + url="https://github.com/meta-llama/llama-stack", packages=find_packages(), classifiers=[], python_requires=">=3.10", diff --git a/tests/example_custom_tool.py b/tests/example_custom_tool.py index ec338982e..f03f18e39 100644 --- a/tests/example_custom_tool.py +++ b/tests/example_custom_tool.py @@ -7,7 +7,7 @@ from typing import Dict from llama_models.llama3.api.datatypes import ToolParamDefinition -from llama_toolchain.tools.custom.datatypes import SingleMessageCustomTool +from llama_stack.tools.custom.datatypes import SingleMessageCustomTool class GetBoilingPointTool(SingleMessageCustomTool): diff --git a/tests/test_e2e.py b/tests/test_e2e.py index ea0246f20..24fc651bd 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -11,12 +11,12 @@ import os import unittest -from llama_toolchain.agentic_system.event_logger import EventLogger, LogEvent -from llama_toolchain.agentic_system.utils import get_agent_system_instance +from llama_stack.agentic_system.event_logger import EventLogger, LogEvent +from llama_stack.agentic_system.utils import get_agent_system_instance from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.agentic_system.api.datatypes import StepType -from llama_toolchain.tools.custom.datatypes import CustomTool +from llama_stack.agentic_system.api.datatypes import StepType +from llama_stack.tools.custom.datatypes import CustomTool from tests.example_custom_tool import GetBoilingPointTool diff --git a/tests/test_inference.py b/tests/test_inference.py index 800046355..ba062046d 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + # Run this test using the following command: # python -m unittest tests/test_inference.py @@ -19,12 +25,12 @@ from llama_models.llama3.api.datatypes import ( UserMessage, ) -from llama_toolchain.inference.api import ( +from llama_stack.inference.api import ( ChatCompletionRequest, ChatCompletionResponseEventType, ) -from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig -from llama_toolchain.inference.meta_reference.inference import get_provider_impl +from llama_stack.inference.meta_reference.config import MetaReferenceImplConfig +from llama_stack.inference.meta_reference.inference import get_provider_impl MODEL = "Meta-Llama3.1-8B-Instruct" diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py index c3cef3a10..878e52991 100644 --- a/tests/test_ollama_inference.py +++ b/tests/test_ollama_inference.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + import textwrap import unittest from datetime import datetime @@ -14,12 +20,12 @@ from llama_models.llama3.api.datatypes import ( ToolResponseMessage, UserMessage, ) -from llama_toolchain.inference.api import ( +from llama_stack.inference.api import ( ChatCompletionRequest, ChatCompletionResponseEventType, ) -from llama_toolchain.inference.ollama.config import OllamaImplConfig -from llama_toolchain.inference.ollama.ollama import get_provider_impl +from llama_stack.inference.ollama.config import OllamaImplConfig +from llama_stack.inference.ollama.ollama import get_provider_impl class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): diff --git a/tests/test_prepare_messages.py b/tests/test_prepare_messages.py index 49624b04d..df3473b4c 100644 --- a/tests/test_prepare_messages.py +++ b/tests/test_prepare_messages.py @@ -1,8 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + import unittest from llama_models.llama3.api import * # noqa: F403 -from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.inference.prepare_messages import prepare_messages +from llama_stack.inference.api import * # noqa: F403 +from llama_stack.inference.prepare_messages import prepare_messages MODEL = "Meta-Llama3.1-8B-Instruct"