From 0068d059db699a37e8d9f49c7f712349a41e4bf1 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sat, 14 Sep 2024 15:02:22 -0700
Subject: [PATCH] move distribution to yaml files

---
 .../cli/stack/list_distributions.py           |  2 +-
 .../conda/local-tgi-conda-example-build.yaml  |  2 +-
 .../distribution_registry/local-ollama.yaml   |  9 +++
 .../local-plus-fireworks-inference.yaml       |  9 +++
 .../local-plus-tgi-inference.yaml             |  8 +++
 .../local-plus-together-inference.yaml        |  9 +++
 .../distribution_registry/local.yaml          |  9 +++
 llama_toolchain/configs/distributions/run.py  | 21 +++++++
 llama_toolchain/core/distribution_registry.py | 58 ++++---------
 9 files changed, 78 insertions(+), 49 deletions(-)
 create mode 100644 llama_toolchain/configs/distributions/distribution_registry/local-ollama.yaml
 create mode 100644 llama_toolchain/configs/distributions/distribution_registry/local-plus-fireworks-inference.yaml
 create mode 100644 llama_toolchain/configs/distributions/distribution_registry/local-plus-tgi-inference.yaml
 create mode 100644 llama_toolchain/configs/distributions/distribution_registry/local-plus-together-inference.yaml
 create mode 100644 llama_toolchain/configs/distributions/distribution_registry/local.yaml
 create mode 100644 llama_toolchain/configs/distributions/run.py

diff --git a/llama_toolchain/cli/stack/list_distributions.py b/llama_toolchain/cli/stack/list_distributions.py
index 557b8c33c..6a4283f74 100644
--- a/llama_toolchain/cli/stack/list_distributions.py
+++ b/llama_toolchain/cli/stack/list_distributions.py
@@ -40,7 +40,7 @@ class StackListDistributions(Subcommand):
 
         rows = []
         for spec in available_distribution_specs():
-            providers = {k.value: v for k, v in spec.providers.items()}
+            providers = {k: v for k, v in spec.providers.items()}
             rows.append(
                 [
                     spec.distribution_type,
diff --git a/llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml b/llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml
index 48b426385..c366adda7 100644
--- a/llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml
+++ b/llama_toolchain/configs/distributions/conda/local-tgi-conda-example-build.yaml
@@ -1,7 +1,7 @@
 name: local-tgi-conda-example
 distribution_spec:
   distribution_type: local-plus-tgi-inference
-  description: Use TGI for running LLM inference
+  description: Use TGI (local or with Hugging Face Inference Endpoints) for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint.
   docker_image: null
   providers:
     inference: remote::tgi
diff --git a/llama_toolchain/configs/distributions/distribution_registry/local-ollama.yaml b/llama_toolchain/configs/distributions/distribution_registry/local-ollama.yaml
new file mode 100644
index 000000000..315fe5291
--- /dev/null
+++ b/llama_toolchain/configs/distributions/distribution_registry/local-ollama.yaml
@@ -0,0 +1,9 @@
+distribution_type: local-ollama
+description: Like local, but use ollama for running LLM inference
+docker_image: null
+providers:
+  inference: remote::ollama
+  safety: meta-reference
+  agentic_system: meta-reference
+  memory: meta-reference-faiss
+  telemetry: console
diff --git a/llama_toolchain/configs/distributions/distribution_registry/local-plus-fireworks-inference.yaml b/llama_toolchain/configs/distributions/distribution_registry/local-plus-fireworks-inference.yaml
new file mode 100644
index 000000000..d68c0015c
--- /dev/null
+++ b/llama_toolchain/configs/distributions/distribution_registry/local-plus-fireworks-inference.yaml
@@ -0,0 +1,9 @@
+distribution_type: local-plus-fireworks-inference
+description: Use Fireworks.ai for running LLM inference
+docker_image: null
+providers:
+  inference: remote::fireworks
+  safety: meta-reference
+  agentic_system: meta-reference
+  memory: meta-reference-faiss
+  telemetry: console
diff --git a/llama_toolchain/configs/distributions/distribution_registry/local-plus-tgi-inference.yaml b/llama_toolchain/configs/distributions/distribution_registry/local-plus-tgi-inference.yaml
new file mode 100644
index 000000000..27caa84e2
--- /dev/null
+++ b/llama_toolchain/configs/distributions/distribution_registry/local-plus-tgi-inference.yaml
@@ -0,0 +1,8 @@
+distribution_type: local-plus-tgi-inference
+description: Use TGI (local or with Hugging Face Inference Endpoints) for running LLM inference. When using HF Inference Endpoints, you must provide the name of the endpoint.
+docker_image: null
+providers:
+  inference: remote::tgi
+  safety: meta-reference
+  agentic_system: meta-reference
+  memory: meta-reference-faiss
diff --git a/llama_toolchain/configs/distributions/distribution_registry/local-plus-together-inference.yaml b/llama_toolchain/configs/distributions/distribution_registry/local-plus-together-inference.yaml
new file mode 100644
index 000000000..507119e1e
--- /dev/null
+++ b/llama_toolchain/configs/distributions/distribution_registry/local-plus-together-inference.yaml
@@ -0,0 +1,9 @@
+distribution_type: local-plus-together-inference
+description: Use Together.ai for running LLM inference
+docker_image: null
+providers:
+  inference: remote::together
+  safety: meta-reference
+  agentic_system: meta-reference
+  memory: meta-reference-faiss
+  telemetry: console
diff --git a/llama_toolchain/configs/distributions/distribution_registry/local.yaml b/llama_toolchain/configs/distributions/distribution_registry/local.yaml
new file mode 100644
index 000000000..ca8aa1fe3
--- /dev/null
+++ b/llama_toolchain/configs/distributions/distribution_registry/local.yaml
@@ -0,0 +1,9 @@
+distribution_type: local
+description: Use code from `llama_toolchain` itself to serve all llama stack APIs
+docker_image: null
+providers:
+  inference: meta-reference
+  memory: meta-reference-faiss
+  safety: meta-reference
+  agentic_system: meta-reference
+  telemetry: console
diff --git a/llama_toolchain/configs/distributions/run.py b/llama_toolchain/configs/distributions/run.py
new file mode 100644
index 000000000..925ab8340
--- /dev/null
+++ b/llama_toolchain/configs/distributions/run.py
@@ -0,0 +1,21 @@
+from llama_toolchain.core.distribution_registry import *
+import json
+
+import fire
+import yaml
+from llama_toolchain.common.serialize import EnumEncoder
+
+
+def main():
+    for d in available_distribution_specs():
+        file_path = "./configs/distributions/distribution_registry/{}.yaml".format(
+            d.distribution_type
+        )
+
+        with open(file_path, "w") as f:
+            to_write = json.loads(json.dumps(d.dict(), cls=EnumEncoder))
+            f.write(yaml.dump(to_write, sort_keys=False))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/llama_toolchain/core/distribution_registry.py b/llama_toolchain/core/distribution_registry.py
index 5af62bfde..0c3e5a90c 100644
--- a/llama_toolchain/core/distribution_registry.py
+++ b/llama_toolchain/core/distribution_registry.py
@@ -5,55 +5,19 @@
 # the root directory of this source tree.
 
 from functools import lru_cache
+from pathlib import Path
 from typing import List, Optional
-
 from .datatypes import *  # noqa: F403
+import yaml
 
 
-@lru_cache()
+# @lru_cache()
 def available_distribution_specs() -> List[DistributionSpec]:
-    return [
-        DistributionSpec(
-            distribution_type="local",
-            description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
-            providers={
-                "inference": "meta-reference",
-                "memory": "meta-reference-faiss",
-                "safety": "meta-reference",
-                "agentic_system": "meta-reference",
-                "telemetry": "console",
-            },
-        ),
-        DistributionSpec(
-            distribution_type="local-ollama",
-            description="Like local, but use ollama for running LLM inference",
-            providers={
-                "inference": remote_provider_type("ollama"),
-                "safety": "meta-reference",
-                "agentic_system": "meta-reference",
-                "memory": "meta-reference-faiss",
-                "telemetry": "console",
-            },
-        ),
-        DistributionSpec(
-            distribution_type="local-plus-fireworks-inference",
-            description="Use Fireworks.ai for running LLM inference",
-            providers={
-                "inference": remote_provider_type("fireworks"),
-                "safety": "meta-reference",
-                "agentic_system": "meta-reference",
-                "memory": "meta-reference-faiss",
-                "telemetry": "console",
-            },
-        ),
-        DistributionSpec(
-            distribution_type="local-plus-tgi-inference",
-            description="Use TGI for running LLM inference",
-            providers={
-                "inference": remote_provider_type("tgi"),
-                "safety": "meta-reference",
-                "agentic_system": "meta-reference",
-                "memory": "meta-reference-faiss",
-            },
-        ),
-    ]
+    distribution_specs = []
+    for p in Path("llama_toolchain/configs/distributions/distribution_registry").rglob(
+        "*.yaml"
+    ):
+        with open(p, "r") as f:
+            distribution_specs.append(DistributionSpec(**yaml.safe_load(f)))
+
+    return distribution_specs
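
After this change, available_distribution_specs() simply deserializes whatever YAML files sit in the
registry directory. As a quick check that the registry files round-trip, here is a minimal sketch that
enumerates the YAML registry the same way the patched function does, without constructing
DistributionSpec objects; the field names come from the YAML files above, and the relative path assumes
the script is run from the repository root, mirroring the hard-coded path in distribution_registry.py:

    # Sketch only: walk the registry directory added by this patch and print
    # each distribution's provider map. Assumes the current working directory
    # is the repository root, since the path is relative.
    from pathlib import Path

    import yaml

    registry = Path("llama_toolchain/configs/distributions/distribution_registry")
    for spec_file in sorted(registry.glob("*.yaml")):
        spec = yaml.safe_load(spec_file.read_text())
        providers = ", ".join(f"{api}: {impl}" for api, impl in spec["providers"].items())
        print(f"{spec['distribution_type']} -> {providers}")

Note that with @lru_cache() commented out in the patch, the registry directory is re-read from disk on
every call to available_distribution_specs().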