diff --git a/.gitmodules b/.gitmodules index f23f58cd8..611875287 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "llama_stack/providers/impls/ios/inference/executorch"] - path = llama_stack/providers/impls/ios/inference/executorch + path = llama_stack/providers/inline/ios/inference/executorch url = https://github.com/pytorch/executorch diff --git a/distributions/bedrock/compose.yaml b/distributions/bedrock/compose.yaml new file mode 100644 index 000000000..f988e33d1 --- /dev/null +++ b/distributions/bedrock/compose.yaml @@ -0,0 +1,15 @@ +services: + llamastack: + image: distribution-bedrock + volumes: + - ~/.llama:/root/.llama + - ./run.yaml:/root/llamastack-run-bedrock.yaml + ports: + - "5000:5000" + entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-bedrock.yaml" + deploy: + restart_policy: + condition: on-failure + delay: 3s + max_attempts: 5 + window: 60s diff --git a/distributions/bedrock/run.yaml b/distributions/bedrock/run.yaml new file mode 100644 index 000000000..bd9a89566 --- /dev/null +++ b/distributions/bedrock/run.yaml @@ -0,0 +1,46 @@ +version: '2' +built_at: '2024-11-01T17:40:45.325529' +image_name: local +name: bedrock +docker_image: null +conda_env: local +apis: +- shields +- agents +- models +- memory +- memory_banks +- inference +- safety +providers: + inference: + - provider_id: bedrock0 + provider_type: remote::bedrock + config: + aws_access_key_id: + aws_secret_access_key: + aws_session_token: + region_name: + memory: + - provider_id: meta0 + provider_type: meta-reference + config: {} + safety: + - provider_id: bedrock0 + provider_type: remote::bedrock + config: + aws_access_key_id: + aws_secret_access_key: + aws_session_token: + region_name: + agents: + - provider_id: meta0 + provider_type: meta-reference + config: + persistence_store: + type: sqlite + db_path: ~/.llama/runtime/kvstore.db + telemetry: + - provider_id: meta0 + provider_type: meta-reference + config: {} diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index 5a330a598..6c36475d9 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -61,49 +61,7 @@ "```\n", "For GPU inference, you need to set these environment variables for specifying local directory containing your model checkpoints, and enable GPU inference to start running docker container.\n", "$ export LLAMA_CHECKPOINT_DIR=~/.llama\n", - "$ llama stack configure llamastack-meta-reference-gpu\n", "```\n", - "Follow the prompts as part of configure.\n", - "Here is a sample output \n", - "```\n", - "$ llama stack configure llamastack-meta-reference-gpu\n", - "\n", - "Could not find ~/.conda/envs/llamastack-llamastack-meta-reference-gpu/llamastack-meta-reference-gpu-build.yaml. Trying docker image name instead...\n", - "+ podman run --network host -it -v ~/.llama/builds/docker:/app/builds llamastack-meta-reference-gpu llama stack configure ./llamastack-build.yaml --output-dir /app/builds\n", - "\n", - "Configuring API `inference`...\n", - "=== Configuring provider `meta-reference` for API inference...\n", - "Enter value for model (default: Llama3.1-8B-Instruct) (required): Llama3.2-11B-Vision-Instruct\n", - "Do you want to configure quantization? (y/n): n\n", - "Enter value for torch_seed (optional): \n", - "Enter value for max_seq_len (default: 4096) (required): \n", - "Enter value for max_batch_size (default: 1) (required): \n", - "\n", - "Configuring API `safety`...\n", - "=== Configuring provider `meta-reference` for API safety...\n", - "Do you want to configure llama_guard_shield? (y/n): n\n", - "Do you want to configure prompt_guard_shield? (y/n): n\n", - "\n", - "Configuring API `agents`...\n", - "=== Configuring provider `meta-reference` for API agents...\n", - "Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite): \n", - "\n", - "Configuring SqliteKVStoreConfig:\n", - "Enter value for namespace (optional): \n", - "Enter value for db_path (default: /root/.llama/runtime/kvstore.db) (required): \n", - "\n", - "Configuring API `memory`...\n", - "=== Configuring provider `meta-reference` for API memory...\n", - "> Please enter the supported memory bank type your provider has for memory: vector\n", - "\n", - "Configuring API `telemetry`...\n", - "=== Configuring provider `meta-reference` for API telemetry...\n", - "\n", - "> YAML configuration has been written to /app/builds/local-gpu-run.yaml.\n", - "You can now run `llama stack run local-gpu --port PORT`\n", - "YAML configuration has been written to /home/hjshah/.llama/builds/docker/local-gpu-run.yaml. You can now run `llama stack run /home/hjshah/.llama/builds/docker/local-gpu-run.yaml`\n", - "```\n", - "NOTE: For this example, we use all local meta-reference implementations and have not setup safety. \n", "\n", "5. Run the Stack Server\n", "```\n", diff --git a/docs/source/api_providers/new_api_provider.md b/docs/source/api_providers/new_api_provider.md index 6d75c38a6..868b5bec2 100644 --- a/docs/source/api_providers/new_api_provider.md +++ b/docs/source/api_providers/new_api_provider.md @@ -6,8 +6,8 @@ This guide contains references to walk you through adding a new API provider. 1. First, decide which API your provider falls into (e.g. Inference, Safety, Agents, Memory). 2. Decide whether your provider is a remote provider, or inline implmentation. A remote provider is a provider that makes a remote request to an service. An inline provider is a provider where implementation is executed locally. Checkout the examples, and follow the structure to add your own API provider. Please find the following code pointers: - - [Inference Remote Adapter](https://github.com/meta-llama/llama-stack/tree/docs/llama_stack/providers/adapters/inference) - - [Inference Inline Provider](https://github.com/meta-llama/llama-stack/tree/docs/llama_stack/providers/impls/meta_reference/inference) + - [Inference Remote Adapter](https://github.com/meta-llama/llama-stack/tree/docs/llama_stack/providers/remote/inference) + - [Inference Inline Provider](https://github.com/meta-llama/llama-stack/tree/docs/llama_stack/providers/inline/meta_reference/inference) 3. [Build a Llama Stack distribution](https://llama-stack.readthedocs.io/en/latest/distribution_dev/building_distro.html) with your API provider. 4. Test your code! diff --git a/docs/source/distribution_dev/building_distro.md b/docs/source/distribution_dev/building_distro.md index 2f1f1b752..82724c40d 100644 --- a/docs/source/distribution_dev/building_distro.md +++ b/docs/source/distribution_dev/building_distro.md @@ -1,53 +1,56 @@ # Developer Guide: Assemble a Llama Stack Distribution -> NOTE: This doc may be out-of-date. -This guide will walk you through the steps to get started with building a Llama Stack distributiom from scratch with your choice of API providers. Please see the [Getting Started Guide](./getting_started.md) if you just want the basic steps to start a Llama Stack distribution. +This guide will walk you through the steps to get started with building a Llama Stack distributiom from scratch with your choice of API providers. Please see the [Getting Started Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) if you just want the basic steps to start a Llama Stack distribution. ## Step 1. Build -In the following steps, imagine we'll be working with a `Meta-Llama3.1-8B-Instruct` model. We will name our build `8b-instruct` to help us remember the config. We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify: -- `name`: the name for our distribution (e.g. `8b-instruct`) + +### Llama Stack Build Options + +``` +llama stack build -h +``` +We will start build our distribution (in the form of a Conda environment, or Docker image). In this step, we will specify: +- `name`: the name for our distribution (e.g. `my-stack`) - `image_type`: our build image type (`conda | docker`) - `distribution_spec`: our distribution specs for specifying API providers - `description`: a short description of the configurations for the distribution - `providers`: specifies the underlying implementation for serving each API endpoint - `image_type`: `conda` | `docker` to specify whether to build the distribution in the form of Docker image or Conda environment. +After this step is complete, a file named `-build.yaml` and template file `-run.yaml` will be generated and saved at the output file path specified at the end of the command. -At the end of build command, we will generate `-build.yaml` file storing the build configurations. +::::{tab-set} +:::{tab-item} Building from Scratch -After this step is complete, a file named `-build.yaml` will be generated and saved at the output file path specified at the end of the command. - -#### Building from scratch - For a new user, we could start off with running `llama stack build` which will allow you to a interactively enter wizard where you will be prompted to enter build configurations. ``` llama stack build + +> Enter a name for your Llama Stack (e.g. my-local-stack): my-stack +> Enter the image type you want your Llama Stack to be built as (docker or conda): conda + +Llama Stack is composed of several APIs working together. Let's select +the provider types (implementations) you want to use for these APIs. + +Tip: use to see options for the providers. + +> Enter provider for API inference: meta-reference +> Enter provider for API safety: meta-reference +> Enter provider for API agents: meta-reference +> Enter provider for API memory: meta-reference +> Enter provider for API datasetio: meta-reference +> Enter provider for API scoring: meta-reference +> Enter provider for API eval: meta-reference +> Enter provider for API telemetry: meta-reference + + > (Optional) Enter a short description for your Llama Stack: + +You can now edit ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml and run `llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml` ``` +::: -Running the command above will allow you to fill in the configuration to build your Llama Stack distribution, you will see the following outputs. - -``` -> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): 8b-instruct -> Enter the image type you want your distribution to be built with (docker or conda): conda - - Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs. -> Enter the API provider for the inference API: (default=meta-reference): meta-reference -> Enter the API provider for the safety API: (default=meta-reference): meta-reference -> Enter the API provider for the agents API: (default=meta-reference): meta-reference -> Enter the API provider for the memory API: (default=meta-reference): meta-reference -> Enter the API provider for the telemetry API: (default=meta-reference): meta-reference - - > (Optional) Enter a short description for your Llama Stack distribution: - -Build spec configuration saved at ~/.conda/envs/llamastack-my-local-llama-stack/8b-instruct-build.yaml -``` - -**Ollama (optional)** - -If you plan to use Ollama for inference, you'll need to install the server [via these instructions](https://ollama.com/download). - - -#### Building from templates +:::{tab-item} Building from a template - To build from alternative API providers, we provide distribution templates for users to get started building a distribution backed by different providers. The following command will allow you to see the available templates and their corresponding providers. @@ -59,18 +62,21 @@ llama stack build --list-templates +------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ | Template Name | Providers | Description | +------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| bedrock | { | Use Amazon Bedrock APIs. | -| | "inference": "remote::bedrock", | | -| | "memory": "meta-reference", | | +| hf-serverless | { | Like local, but use Hugging Face Inference API (serverless) for running LLM | +| | "inference": "remote::hf::serverless", | inference. | +| | "memory": "meta-reference", | See https://hf.co/docs/api-inference. | | | "safety": "meta-reference", | | | | "agents": "meta-reference", | | | | "telemetry": "meta-reference" | | | | } | | +------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| databricks | { | Use Databricks for running LLM inference | -| | "inference": "remote::databricks", | | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | +| together | { | Use Together.ai for running LLM inference | +| | "inference": "remote::together", | | +| | "memory": [ | | +| | "meta-reference", | | +| | "remote::weaviate" | | +| | ], | | +| | "safety": "remote::together", | | | | "agents": "meta-reference", | | | | "telemetry": "meta-reference" | | | | } | | @@ -88,17 +94,37 @@ llama stack build --list-templates | | "telemetry": "meta-reference" | | | | } | | +------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| hf-endpoint | { | Like local, but use Hugging Face Inference Endpoints for running LLM inference. | -| | "inference": "remote::hf::endpoint", | See https://hf.co/docs/api-endpoints. | +| databricks | { | Use Databricks for running LLM inference | +| | "inference": "remote::databricks", | | | | "memory": "meta-reference", | | | | "safety": "meta-reference", | | | | "agents": "meta-reference", | | | | "telemetry": "meta-reference" | | | | } | | +------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| hf-serverless | { | Like local, but use Hugging Face Inference API (serverless) for running LLM | -| | "inference": "remote::hf::serverless", | inference. | -| | "memory": "meta-reference", | See https://hf.co/docs/api-inference. | +| vllm | { | Like local, but use vLLM for running LLM inference | +| | "inference": "vllm", | | +| | "memory": "meta-reference", | | +| | "safety": "meta-reference", | | +| | "agents": "meta-reference", | | +| | "telemetry": "meta-reference" | | +| | } | | ++------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ +| tgi | { | Use TGI for running LLM inference | +| | "inference": "remote::tgi", | | +| | "memory": [ | | +| | "meta-reference", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": "meta-reference", | | +| | "agents": "meta-reference", | | +| | "telemetry": "meta-reference" | | +| | } | | ++------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ +| bedrock | { | Use Amazon Bedrock APIs. | +| | "inference": "remote::bedrock", | | +| | "memory": "meta-reference", | | | | "safety": "meta-reference", | | | | "agents": "meta-reference", | | | | "telemetry": "meta-reference" | | @@ -140,31 +166,8 @@ llama stack build --list-templates | | "telemetry": "meta-reference" | | | | } | | +------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| tgi | { | Use TGI for running LLM inference | -| | "inference": "remote::tgi", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| together | { | Use Together.ai for running LLM inference | -| | "inference": "remote::together", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::weaviate" | | -| | ], | | -| | "safety": "remote::together", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| vllm | { | Like local, but use vLLM for running LLM inference | -| | "inference": "vllm", | | +| hf-endpoint | { | Like local, but use Hugging Face Inference Endpoints for running LLM inference. | +| | "inference": "remote::hf::endpoint", | See https://hf.co/docs/api-endpoints. | | | "memory": "meta-reference", | | | | "safety": "meta-reference", | | | | "agents": "meta-reference", | | @@ -175,6 +178,7 @@ llama stack build --list-templates You may then pick a template to build your distribution with providers fitted to your liking. +For example, to build a distribution with TGI as the inference provider, you can run: ``` llama stack build --template tgi ``` @@ -182,15 +186,14 @@ llama stack build --template tgi ``` $ llama stack build --template tgi ... -... -Build spec configuration saved at ~/.conda/envs/llamastack-tgi/tgi-build.yaml -You may now run `llama stack configure tgi` or `llama stack configure ~/.conda/envs/llamastack-tgi/tgi-build.yaml` +You can now edit ~/.llama/distributions/llamastack-tgi/tgi-run.yaml and run `llama stack run ~/.llama/distributions/llamastack-tgi/tgi-run.yaml` ``` +::: -#### Building from config file +:::{tab-item} Building from a pre-existing build config file - In addition to templates, you may customize the build to your liking through editing config files and build from config files with the following command. -- The config file will be of contents like the ones in `llama_stack/distributions/templates/`. +- The config file will be of contents like the ones in `llama_stack/templates/*build.yaml`. ``` $ cat llama_stack/templates/ollama/build.yaml @@ -210,148 +213,111 @@ image_type: conda ``` llama stack build --config llama_stack/templates/ollama/build.yaml ``` +::: -#### How to build distribution with Docker image - +:::{tab-item} Building Docker > [!TIP] > Podman is supported as an alternative to Docker. Set `DOCKER_BINARY` to `podman` in your environment to use Podman. To build a docker image, you may start off from a template and use the `--image-type docker` flag to specify `docker` as the build image type. ``` -llama stack build --template local --image-type docker +llama stack build --template ollama --image-type docker ``` -Alternatively, you may use a config file and set `image_type` to `docker` in our `-build.yaml` file, and run `llama stack build -build.yaml`. The `-build.yaml` will be of contents like: - ``` -name: local-docker-example -distribution_spec: - description: Use code from `llama_stack` itself to serve all llama stack APIs - docker_image: null - providers: - inference: meta-reference - memory: meta-reference-faiss - safety: meta-reference - agentic_system: meta-reference - telemetry: console -image_type: docker -``` - -The following command allows you to build a Docker image with the name `` -``` -llama stack build --config -build.yaml - -Dockerfile created successfully in /tmp/tmp.I0ifS2c46A/DockerfileFROM python:3.10-slim -WORKDIR /app +$ llama stack build --template ollama --image-type docker ... +Dockerfile created successfully in /tmp/tmp.viA3a3Rdsg/DockerfileFROM python:3.10-slim ... -You can run it with: podman run -p 8000:8000 llamastack-docker-local -Build spec configuration saved at ~/.llama/distributions/docker/docker-local-build.yaml + +You can now edit ~/meta-llama/llama-stack/tmp/configs/ollama-run.yaml and run `llama stack run ~/meta-llama/llama-stack/tmp/configs/ollama-run.yaml` ``` +After this step is successful, you should be able to find the built docker image and test it with `llama stack run `. +::: -## Step 2. Configure -After our distribution is built (either in form of docker or conda environment), we will run the following command to -``` -llama stack configure [ | ] -``` -- For `conda` environments: would be the generated build spec saved from Step 1. -- For `docker` images downloaded from Dockerhub, you could also use as the argument. - - Run `docker images` to check list of available images on your machine. +:::: + + +## Step 2. Run +Now, let's start the Llama Stack Distribution Server. You will need the YAML configuration file which was written out at the end by the `llama stack build` step. ``` -$ llama stack configure tgi - -Configuring API: inference (meta-reference) -Enter value for model (existing: Meta-Llama3.1-8B-Instruct) (required): -Enter value for quantization (optional): -Enter value for torch_seed (optional): -Enter value for max_seq_len (existing: 4096) (required): -Enter value for max_batch_size (existing: 1) (required): - -Configuring API: memory (meta-reference-faiss) - -Configuring API: safety (meta-reference) -Do you want to configure llama_guard_shield? (y/n): y -Entering sub-configuration for llama_guard_shield: -Enter value for model (default: Llama-Guard-3-1B) (required): -Enter value for excluded_categories (default: []) (required): -Enter value for disable_input_check (default: False) (required): -Enter value for disable_output_check (default: False) (required): -Do you want to configure prompt_guard_shield? (y/n): y -Entering sub-configuration for prompt_guard_shield: -Enter value for model (default: Prompt-Guard-86M) (required): - -Configuring API: agentic_system (meta-reference) -Enter value for brave_search_api_key (optional): -Enter value for bing_search_api_key (optional): -Enter value for wolfram_api_key (optional): - -Configuring API: telemetry (console) - -YAML configuration has been written to ~/.llama/builds/conda/tgi-run.yaml +llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml ``` -After this step is successful, you should be able to find a run configuration spec in `~/.llama/builds/conda/tgi-run.yaml` with the following contents. You may edit this file to change the settings. - -As you can see, we did basic configuration above and configured: -- inference to run on model `Meta-Llama3.1-8B-Instruct` (obtained from `llama model list`) -- Llama Guard safety shield with model `Llama-Guard-3-1B` -- Prompt Guard safety shield with model `Prompt-Guard-86M` - -For how these configurations are stored as yaml, checkout the file printed at the end of the configuration. - -Note that all configurations as well as models are stored in `~/.llama` - - -## Step 3. Run -Now, let's start the Llama Stack Distribution Server. You will need the YAML configuration file which was written out at the end by the `llama stack configure` step. - -``` -llama stack run 8b-instruct ``` +$ llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml -You should see the Llama Stack server start and print the APIs that it is supporting +Loaded model... +Serving API datasets + GET /datasets/get + GET /datasets/list + POST /datasets/register +Serving API inspect + GET /health + GET /providers/list + GET /routes/list +Serving API inference + POST /inference/chat_completion + POST /inference/completion + POST /inference/embeddings +Serving API scoring_functions + GET /scoring_functions/get + GET /scoring_functions/list + POST /scoring_functions/register +Serving API scoring + POST /scoring/score + POST /scoring/score_batch +Serving API memory_banks + GET /memory_banks/get + GET /memory_banks/list + POST /memory_banks/register +Serving API memory + POST /memory/insert + POST /memory/query +Serving API safety + POST /safety/run_shield +Serving API eval + POST /eval/evaluate + POST /eval/evaluate_batch + POST /eval/job/cancel + GET /eval/job/result + GET /eval/job/status +Serving API shields + GET /shields/get + GET /shields/list + POST /shields/register +Serving API datasetio + GET /datasetio/get_rows_paginated +Serving API telemetry + GET /telemetry/get_trace + POST /telemetry/log_event +Serving API models + GET /models/get + GET /models/list + POST /models/register +Serving API agents + POST /agents/create + POST /agents/session/create + POST /agents/turn/create + POST /agents/delete + POST /agents/session/delete + POST /agents/session/get + POST /agents/step/get + POST /agents/turn/get -``` -$ llama stack run 8b-instruct - -> initializing model parallel with size 1 -> initializing ddp with size 1 -> initializing pipeline with size 1 -Loaded in 19.28 seconds -NCCL version 2.20.5+cuda12.4 -Finished model load YES READY -Serving POST /inference/batch_chat_completion -Serving POST /inference/batch_completion -Serving POST /inference/chat_completion -Serving POST /inference/completion -Serving POST /safety/run_shield -Serving POST /agentic_system/memory_bank/attach -Serving POST /agentic_system/create -Serving POST /agentic_system/session/create -Serving POST /agentic_system/turn/create -Serving POST /agentic_system/delete -Serving POST /agentic_system/session/delete -Serving POST /agentic_system/memory_bank/detach -Serving POST /agentic_system/session/get -Serving POST /agentic_system/step/get -Serving POST /agentic_system/turn/get -Listening on :::5000 -INFO: Started server process [453333] +Listening on ['::', '0.0.0.0']:5000 +INFO: Started server process [2935911] INFO: Waiting for application startup. INFO: Application startup complete. -INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit) +INFO: Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit) +INFO: 2401:db00:35c:2d2b:face:0:c9:0:54678 - "GET /models/list HTTP/1.1" 200 OK ``` -> [!NOTE] -> Configuration is in `~/.llama/builds/local/conda/tgi-run.yaml`. Feel free to increase `max_seq_len`. - > [!IMPORTANT] > The "local" distribution inference server currently only supports CUDA. It will not work on Apple Silicon machines. > [!TIP] > You might need to use the flag `--disable-ipv6` to Disable IPv6 support - -This server is running a Llama model locally. diff --git a/docs/source/getting_started/distributions/ondevice_distro/ios_sdk.md b/docs/source/getting_started/distributions/ondevice_distro/ios_sdk.md index 08885ad73..ea65ecd82 100644 --- a/docs/source/getting_started/distributions/ondevice_distro/ios_sdk.md +++ b/docs/source/getting_started/distributions/ondevice_distro/ios_sdk.md @@ -3,7 +3,7 @@ We offer both remote and on-device use of Llama Stack in Swift via two components: 1. [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift/) -2. [LocalInferenceImpl](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/impls/ios/inference) +2. [LocalInferenceImpl](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline/ios/inference) ```{image} ../../../../_static/remote_or_local.gif :alt: Seamlessly switching between local, on-device inference and remote hosted inference diff --git a/docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md b/docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md new file mode 100644 index 000000000..28691d4e3 --- /dev/null +++ b/docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md @@ -0,0 +1,58 @@ +# Bedrock Distribution + +### Connect to a Llama Stack Bedrock Endpoint +- You may connect to Amazon Bedrock APIs for running LLM inference + +The `llamastack/distribution-bedrock` distribution consists of the following provider configurations. + + +| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** | +|----------------- |--------------- |---------------- |---------------- |---------------- |---------------- | +| **Provider(s)** | remote::bedrock | meta-reference | meta-reference | remote::bedrock | meta-reference | + + +### Docker: Start the Distribution (Single Node CPU) + +> [!NOTE] +> This assumes you have valid AWS credentials configured with access to Amazon Bedrock. + +``` +$ cd distributions/bedrock && docker compose up +``` + +Make sure in your `run.yaml` file, your inference provider is pointing to the correct AWS configuration. E.g. +``` +inference: + - provider_id: bedrock0 + provider_type: remote::bedrock + config: + aws_access_key_id: + aws_secret_access_key: + aws_session_token: + region_name: +``` + +### Conda llama stack run (Single Node CPU) + +```bash +llama stack build --template bedrock --image-type conda +# -- modify run.yaml with valid AWS credentials +llama stack run ./run.yaml +``` + +### (Optional) Update Model Serving Configuration + +Use `llama-stack-client models list` to check the available models served by Amazon Bedrock. + +``` +$ llama-stack-client models list ++------------------------------+------------------------------+---------------+------------+ +| identifier | llama_model | provider_id | metadata | ++==============================+==============================+===============+============+ +| Llama3.1-8B-Instruct | meta.llama3-1-8b-instruct-v1:0 | bedrock0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.1-70B-Instruct | meta.llama3-1-70b-instruct-v1:0 | bedrock0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.1-405B-Instruct | meta.llama3-1-405b-instruct-v1:0 | bedrock0 | {} | ++------------------------------+------------------------------+---------------+------------+ +``` diff --git a/docs/source/getting_started/distributions/self_hosted_distro/meta-reference-quantized-gpu.md b/docs/source/getting_started/distributions/self_hosted_distro/meta-reference-quantized-gpu.md index 0c05a13c1..afe1e3e20 100644 --- a/docs/source/getting_started/distributions/self_hosted_distro/meta-reference-quantized-gpu.md +++ b/docs/source/getting_started/distributions/self_hosted_distro/meta-reference-quantized-gpu.md @@ -9,7 +9,19 @@ The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc. -### Start the Distribution (Single Node GPU) +### Step 0. Prerequisite - Downloading Models +Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/cli_reference/download_models.html) here to download the models. + +``` +$ ls ~/.llama/checkpoints +Llama3.2-3B-Instruct:int4-qlora-eo8 +``` + +### Step 1. Start the Distribution +#### (Option 1) Start with Docker +``` +$ cd distributions/meta-reference-quantized-gpu && docker compose up +``` > [!NOTE] > This assumes you have access to GPU to start a local server with access to your GPU. @@ -19,16 +31,24 @@ The only difference vs. the `meta-reference-gpu` distribution is that it has sup > `~/.llama` should be the path containing downloaded weights of Llama models. -To download and start running a pre-built docker container, you may use the following commands: +This will download and start running a pre-built docker container. Alternatively, you may use the following commands: ``` -docker run -it -p 5000:5000 -v ~/.llama:/root/.llama \ - -v ./run.yaml:/root/my-run.yaml \ - --gpus=all \ - distribution-meta-reference-quantized-gpu \ - --yaml_config /root/my-run.yaml +docker run -it -p 5000:5000 -v ~/.llama:/root/.llama -v ./run.yaml:/root/my-run.yaml --gpus=all distribution-meta-reference-quantized-gpu --yaml_config /root/my-run.yaml ``` -### Alternative (Build and start distribution locally via conda) +#### (Option 2) Start with Conda -- You may checkout the [Getting Started](../../docs/getting_started.md) for more details on building locally via conda and starting up the distribution. +1. Install the `llama` CLI. See [CLI Reference](https://llama-stack.readthedocs.io/en/latest/cli_reference/index.html) + +2. Build the `meta-reference-quantized-gpu` distribution + +``` +$ llama stack build --template meta-reference-quantized-gpu --image-type conda +``` + +3. Start running distribution +``` +$ cd distributions/meta-reference-quantized-gpu +$ llama stack run ./run.yaml +``` diff --git a/docs/source/getting_started/distributions/self_hosted_distro/ollama.md b/docs/source/getting_started/distributions/self_hosted_distro/ollama.md index 003656e2b..0d4d90ee6 100644 --- a/docs/source/getting_started/distributions/self_hosted_distro/ollama.md +++ b/docs/source/getting_started/distributions/self_hosted_distro/ollama.md @@ -102,7 +102,7 @@ ollama pull llama3.1:70b-instruct-fp16 ``` > [!NOTE] -> Please check the [OLLAMA_SUPPORTED_MODELS](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/adapters/inference/ollama/ollama.py) for the supported Ollama models. +> Please check the [OLLAMA_SUPPORTED_MODELS](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers.remote/inference/ollama/ollama.py) for the supported Ollama models. To serve a new model with `ollama` diff --git a/docs/source/getting_started/index.md b/docs/source/getting_started/index.md index c79a6dce7..c99b5f8f9 100644 --- a/docs/source/getting_started/index.md +++ b/docs/source/getting_started/index.md @@ -386,7 +386,7 @@ ollama pull llama3.1:8b-instruct-fp16 ollama pull llama3.1:70b-instruct-fp16 ``` -> Please check the [OLLAMA_SUPPORTED_MODELS](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/adapters/inference/ollama/ollama.py) for the supported Ollama models. +> Please check the [OLLAMA_SUPPORTED_MODELS](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers.remote/inference/ollama/ollama.py) for the supported Ollama models. To serve a new model with `ollama` diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index f3615dc4b..0b74fd259 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -39,7 +39,7 @@ class RunShieldResponse(BaseModel): class ShieldStore(Protocol): - def get_shield(self, identifier: str) -> ShieldDef: ... + async def get_shield(self, identifier: str) -> ShieldDef: ... @runtime_checkable @@ -48,5 +48,5 @@ class Safety(Protocol): @webmethod(route="/safety/run_shield") async def run_shield( - self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + self, identifier: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: ... diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 7c8e3939a..fd5634442 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -46,7 +46,7 @@ class Shields(Protocol): async def list_shields(self) -> List[ShieldDefWithProvider]: ... @webmethod(route="/shields/get", method="GET") - async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ... + async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: ... @webmethod(route="/shields/register", method="POST") async def register_shield(self, shield: ShieldDefWithProvider) -> None: ... diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index 0ba39265b..94d41cfab 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -12,6 +12,10 @@ import os from functools import lru_cache from pathlib import Path +from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.utils.dynamic import instantiate_class_type + + TEMPLATES_PATH = Path(os.path.relpath(__file__)).parent.parent.parent / "templates" @@ -176,6 +180,66 @@ class StackBuild(Subcommand): return self._run_stack_build_command_from_build_config(build_config) + def _generate_run_config(self, build_config: BuildConfig, build_dir: Path) -> None: + """ + Generate a run.yaml template file for user to edit from a build.yaml file + """ + import json + + import yaml + from termcolor import cprint + + from llama_stack.distribution.build import ImageType + + apis = list(build_config.distribution_spec.providers.keys()) + run_config = StackRunConfig( + built_at=datetime.now(), + docker_image=( + build_config.name + if build_config.image_type == ImageType.docker.value + else None + ), + image_name=build_config.name, + conda_env=( + build_config.name + if build_config.image_type == ImageType.conda.value + else None + ), + apis=apis, + providers={}, + ) + # build providers dict + provider_registry = get_provider_registry() + for api in apis: + run_config.providers[api] = [] + provider_types = build_config.distribution_spec.providers[api] + if isinstance(provider_types, str): + provider_types = [provider_types] + + for i, provider_type in enumerate(provider_types): + p_spec = Provider( + provider_id=f"{provider_type}-{i}", + provider_type=provider_type, + config={}, + ) + config_type = instantiate_class_type( + provider_registry[Api(api)][provider_type].config_class + ) + p_spec.config = config_type() + run_config.providers[api].append(p_spec) + + os.makedirs(build_dir, exist_ok=True) + run_config_file = build_dir / f"{build_config.name}-run.yaml" + + with open(run_config_file, "w") as f: + to_write = json.loads(run_config.model_dump_json()) + f.write(yaml.dump(to_write, sort_keys=False)) + + cprint( + f"You can now edit {run_config_file} and run `llama stack run {run_config_file}`", + color="green", + ) + def _run_stack_build_command_from_build_config( self, build_config: BuildConfig ) -> None: @@ -183,48 +247,24 @@ class StackBuild(Subcommand): import os import yaml - from termcolor import cprint - from llama_stack.distribution.build import build_image, ImageType + from llama_stack.distribution.build import build_image from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR - from llama_stack.distribution.utils.serialize import EnumEncoder # save build.yaml spec for building same distribution again - if build_config.image_type == ImageType.docker.value: - # docker needs build file to be in the llama-stack repo dir to be able to copy over to the image - llama_stack_path = Path( - os.path.abspath(__file__) - ).parent.parent.parent.parent - build_dir = llama_stack_path / "tmp/configs/" - else: - build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}" - + build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}" os.makedirs(build_dir, exist_ok=True) build_file_path = build_dir / f"{build_config.name}-build.yaml" with open(build_file_path, "w") as f: - to_write = json.loads(json.dumps(build_config.dict(), cls=EnumEncoder)) + to_write = json.loads(build_config.model_dump_json()) f.write(yaml.dump(to_write, sort_keys=False)) return_code = build_image(build_config, build_file_path) if return_code != 0: return - configure_name = ( - build_config.name - if build_config.image_type == "conda" - else (f"llamastack-{build_config.name}") - ) - if build_config.image_type == "conda": - cprint( - f"You can now run `llama stack configure {configure_name}`", - color="green", - ) - else: - cprint( - f"You can now edit your run.yaml file and run `docker run -it -p 5000:5000 {build_config.name}`. See full command in llama-stack/distributions/", - color="green", - ) + self._generate_run_config(build_config, build_dir) def _run_template_list_cmd(self, args: argparse.Namespace) -> None: import json diff --git a/llama_stack/cli/stack/configure.py b/llama_stack/cli/stack/configure.py index 779bb90fc..7aa1bb6ed 100644 --- a/llama_stack/cli/stack/configure.py +++ b/llama_stack/cli/stack/configure.py @@ -7,8 +7,6 @@ import argparse from llama_stack.cli.subcommand import Subcommand -from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR -from llama_stack.distribution.datatypes import * # noqa: F403 class StackConfigure(Subcommand): @@ -39,123 +37,10 @@ class StackConfigure(Subcommand): ) def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None: - import json - import os - import subprocess - from pathlib import Path - - import pkg_resources - - import yaml - from termcolor import cprint - - from llama_stack.distribution.build import ImageType - from llama_stack.distribution.utils.exec import run_with_pty - - docker_image = None - - build_config_file = Path(args.config) - if build_config_file.exists(): - with open(build_config_file, "r") as f: - build_config = BuildConfig(**yaml.safe_load(f)) - self._configure_llama_distribution(build_config, args.output_dir) - return - - conda_dir = ( - Path(os.path.expanduser("~/.conda/envs")) / f"llamastack-{args.config}" - ) - output = subprocess.check_output(["bash", "-c", "conda info --json"]) - conda_envs = json.loads(output.decode("utf-8"))["envs"] - - for x in conda_envs: - if x.endswith(f"/llamastack-{args.config}"): - conda_dir = Path(x) - break - - build_config_file = Path(conda_dir) / f"{args.config}-build.yaml" - if build_config_file.exists(): - with open(build_config_file, "r") as f: - build_config = BuildConfig(**yaml.safe_load(f)) - - cprint(f"Using {build_config_file}...", "green") - self._configure_llama_distribution(build_config, args.output_dir) - return - - docker_image = args.config - builds_dir = BUILDS_BASE_DIR / ImageType.docker.value - if args.output_dir: - builds_dir = Path(output_dir) - os.makedirs(builds_dir, exist_ok=True) - - script = pkg_resources.resource_filename( - "llama_stack", "distribution/configure_container.sh" - ) - script_args = [script, docker_image, str(builds_dir)] - - return_code = run_with_pty(script_args) - if return_code != 0: - self.parser.error( - f"Failed to configure container {docker_image} with return code {return_code}. Please run `llama stack build` first. " - ) - - def _configure_llama_distribution( - self, - build_config: BuildConfig, - output_dir: Optional[str] = None, - ): - import json - import os - from pathlib import Path - - import yaml - from termcolor import cprint - - from llama_stack.distribution.configure import ( - configure_api_providers, - parse_and_maybe_upgrade_config, - ) - from llama_stack.distribution.utils.serialize import EnumEncoder - - builds_dir = BUILDS_BASE_DIR / build_config.image_type - if output_dir: - builds_dir = Path(output_dir) - os.makedirs(builds_dir, exist_ok=True) - image_name = build_config.name.replace("::", "-") - run_config_file = builds_dir / f"{image_name}-run.yaml" - - if run_config_file.exists(): - cprint( - f"Configuration already exists at `{str(run_config_file)}`. Will overwrite...", - "yellow", - attrs=["bold"], - ) - config_dict = yaml.safe_load(run_config_file.read_text()) - config = parse_and_maybe_upgrade_config(config_dict) - else: - config = StackRunConfig( - built_at=datetime.now(), - image_name=image_name, - apis=list(build_config.distribution_spec.providers.keys()), - providers={}, - ) - - config = configure_api_providers(config, build_config.distribution_spec) - - config.docker_image = ( - image_name if build_config.image_type == "docker" else None - ) - config.conda_env = image_name if build_config.image_type == "conda" else None - - with open(run_config_file, "w") as f: - to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder)) - f.write(yaml.dump(to_write, sort_keys=False)) - - cprint( - f"> YAML configuration has been written to `{run_config_file}`.", - color="blue", - ) - - cprint( - f"You can now run `llama stack run {image_name} --port PORT`", - color="green", + self.parser.error( + """ + DEPRECATED! llama stack configure has been deprecated. + Please use llama stack run --config instead. + Please see example run.yaml in /distributions folder. + """ ) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index dd4247e4b..842703d4c 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -45,7 +45,6 @@ class StackRun(Subcommand): import pkg_resources import yaml - from termcolor import cprint from llama_stack.distribution.build import ImageType from llama_stack.distribution.configure import parse_and_maybe_upgrade_config @@ -71,14 +70,12 @@ class StackRun(Subcommand): if not config_file.exists(): self.parser.error( - f"File {str(config_file)} does not exist. Please run `llama stack build` and `llama stack configure ` to generate a run.yaml file" + f"File {str(config_file)} does not exist. Please run `llama stack build` to generate (and optionally edit) a run.yaml file" ) return - cprint(f"Using config `{config_file}`", "green") - with open(config_file, "r") as f: - config_dict = yaml.safe_load(config_file.read_text()) - config = parse_and_maybe_upgrade_config(config_dict) + config_dict = yaml.safe_load(config_file.read_text()) + config = parse_and_maybe_upgrade_config(config_dict) if config.docker_image: script = pkg_resources.resource_filename( diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index ae2b17d9e..e5ec5b4e2 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -36,7 +36,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")") REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR")) DOCKER_BINARY=${DOCKER_BINARY:-docker} DOCKER_OPTS=${DOCKER_OPTS:-} -REPO_CONFIGS_DIR="$REPO_DIR/tmp/configs" TEMP_DIR=$(mktemp -d) @@ -115,8 +114,6 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"] EOF -add_to_docker "ADD tmp/configs/$(basename "$build_file_path") ./llamastack-build.yaml" - printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile" cat $TEMP_DIR/Dockerfile printf "\n" @@ -138,7 +135,6 @@ set -x $DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REPO_DIR" $mounts # clean up tmp/configs -rm -rf $REPO_CONFIGS_DIR set +x echo "Success!" diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 348d8449d..760dbaf2f 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -154,12 +154,12 @@ class SafetyRouter(Safety): async def run_shield( self, - shield_type: str, + identifier: str, messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: - return await self.routing_table.get_provider_impl(shield_type).run_shield( - shield_type=shield_type, + return await self.routing_table.get_provider_impl(identifier).run_shield( + identifier=identifier, messages=messages, params=params, ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 1efd02c89..bcf125bec 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -182,6 +182,12 @@ class CommonRoutingTableImpl(RoutingTable): objs = await self.dist_registry.get_all() return [obj for obj in objs if obj.type == type] + async def get_all_with_types( + self, types: List[str] + ) -> List[RoutableObjectWithProvider]: + objs = await self.dist_registry.get_all() + return [obj for obj in objs if obj.type in types] + class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> List[ModelDefWithProvider]: @@ -198,8 +204,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[ShieldDef]: return await self.get_all_with_type("shield") - async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: - return await self.get_object_by_identifier(shield_type) + async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: + return await self.get_object_by_identifier(identifier) async def register_shield(self, shield: ShieldDefWithProvider) -> None: await self.register_object(shield) @@ -207,7 +213,14 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: - return await self.get_all_with_type("memory_bank") + return await self.get_all_with_types( + [ + MemoryBankType.vector.value, + MemoryBankType.keyvalue.value, + MemoryBankType.keyword.value, + MemoryBankType.graph.value, + ] + ) async def get_memory_bank( self, identifier: str diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 2560f4070..16c0fd0e0 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -209,7 +209,8 @@ async def maybe_await(value): async def sse_generator(event_gen): try: - async for item in await event_gen: + event_gen = await event_gen + async for item in event_gen: yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: @@ -229,7 +230,6 @@ async def sse_generator(event_gen): def create_dynamic_typed_route(func: Any, method: str): - async def endpoint(request: Request, **kwargs): await start_trace(func.__name__) diff --git a/llama_stack/providers/adapters/safety/bedrock/config.py b/llama_stack/providers/adapters/safety/bedrock/config.py deleted file mode 100644 index 2a8585262..000000000 --- a/llama_stack/providers/adapters/safety/bedrock/config.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from pydantic import BaseModel, Field - - -class BedrockSafetyConfig(BaseModel): - """Configuration information for a guardrail that you want to use in the request.""" - - aws_profile: str = Field( - default="default", - description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation", - ) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 9a37a28a9..69255fc5f 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -145,11 +145,12 @@ Fully-qualified name of the module to import. The module is expected to have: class RemoteProviderConfig(BaseModel): host: str = "localhost" - port: int + port: int = 0 + protocol: str = "http" @property def url(self) -> str: - return f"http://{self.host}:{self.port}" + return f"{self.protocol}://{self.host}:{self.port}" @json_schema_type diff --git a/llama_stack/providers/adapters/__init__.py b/llama_stack/providers/inline/__init__.py similarity index 100% rename from llama_stack/providers/adapters/__init__.py rename to llama_stack/providers/inline/__init__.py diff --git a/llama_stack/providers/impls/braintrust/scoring/__init__.py b/llama_stack/providers/inline/braintrust/scoring/__init__.py similarity index 100% rename from llama_stack/providers/impls/braintrust/scoring/__init__.py rename to llama_stack/providers/inline/braintrust/scoring/__init__.py diff --git a/llama_stack/providers/impls/braintrust/scoring/braintrust.py b/llama_stack/providers/inline/braintrust/scoring/braintrust.py similarity index 98% rename from llama_stack/providers/impls/braintrust/scoring/braintrust.py rename to llama_stack/providers/inline/braintrust/scoring/braintrust.py index 826d60379..6488a63eb 100644 --- a/llama_stack/providers/impls/braintrust/scoring/braintrust.py +++ b/llama_stack/providers/inline/braintrust/scoring/braintrust.py @@ -16,7 +16,7 @@ from llama_stack.apis.datasets import * # noqa: F403 from autoevals.llm import Factuality from autoevals.ragas import AnswerCorrectness from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import ( aggregate_average, ) diff --git a/llama_stack/providers/impls/braintrust/scoring/config.py b/llama_stack/providers/inline/braintrust/scoring/config.py similarity index 100% rename from llama_stack/providers/impls/braintrust/scoring/config.py rename to llama_stack/providers/inline/braintrust/scoring/config.py diff --git a/llama_stack/providers/adapters/agents/__init__.py b/llama_stack/providers/inline/braintrust/scoring/scoring_fn/__init__.py similarity index 100% rename from llama_stack/providers/adapters/agents/__init__.py rename to llama_stack/providers/inline/braintrust/scoring/scoring_fn/__init__.py diff --git a/llama_stack/providers/adapters/inference/__init__.py b/llama_stack/providers/inline/braintrust/scoring/scoring_fn/fn_defs/__init__.py similarity index 100% rename from llama_stack/providers/adapters/inference/__init__.py rename to llama_stack/providers/inline/braintrust/scoring/scoring_fn/fn_defs/__init__.py diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/answer_correctness.py b/llama_stack/providers/inline/braintrust/scoring/scoring_fn/fn_defs/answer_correctness.py similarity index 100% rename from llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/answer_correctness.py rename to llama_stack/providers/inline/braintrust/scoring/scoring_fn/fn_defs/answer_correctness.py diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.py b/llama_stack/providers/inline/braintrust/scoring/scoring_fn/fn_defs/factuality.py similarity index 100% rename from llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.py rename to llama_stack/providers/inline/braintrust/scoring/scoring_fn/fn_defs/factuality.py diff --git a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj similarity index 100% rename from llama_stack/providers/impls/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj rename to llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj diff --git a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata similarity index 100% rename from llama_stack/providers/impls/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata rename to llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata diff --git a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist similarity index 100% rename from llama_stack/providers/impls/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist rename to llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist diff --git a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/LocalInference.h b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h similarity index 100% rename from llama_stack/providers/impls/ios/inference/LocalInferenceImpl/LocalInference.h rename to llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h diff --git a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/LocalInference.swift b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift similarity index 100% rename from llama_stack/providers/impls/ios/inference/LocalInferenceImpl/LocalInference.swift rename to llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift diff --git a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/Parsing.swift b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift similarity index 100% rename from llama_stack/providers/impls/ios/inference/LocalInferenceImpl/Parsing.swift rename to llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift diff --git a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/PromptTemplate.swift b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift similarity index 100% rename from llama_stack/providers/impls/ios/inference/LocalInferenceImpl/PromptTemplate.swift rename to llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift diff --git a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/SystemPrompts.swift b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift similarity index 100% rename from llama_stack/providers/impls/ios/inference/LocalInferenceImpl/SystemPrompts.swift rename to llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift diff --git a/llama_stack/providers/impls/ios/inference/executorch b/llama_stack/providers/inline/ios/inference/executorch similarity index 100% rename from llama_stack/providers/impls/ios/inference/executorch rename to llama_stack/providers/inline/ios/inference/executorch diff --git a/llama_stack/providers/adapters/memory/__init__.py b/llama_stack/providers/inline/meta_reference/__init__.py similarity index 100% rename from llama_stack/providers/adapters/memory/__init__.py rename to llama_stack/providers/inline/meta_reference/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/agents/__init__.py b/llama_stack/providers/inline/meta_reference/agents/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/__init__.py rename to llama_stack/providers/inline/meta_reference/agents/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/inline/meta_reference/agents/agent_instance.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/agent_instance.py rename to llama_stack/providers/inline/meta_reference/agents/agent_instance.py diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/inline/meta_reference/agents/agents.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/agents.py rename to llama_stack/providers/inline/meta_reference/agents/agents.py diff --git a/llama_stack/providers/impls/meta_reference/agents/config.py b/llama_stack/providers/inline/meta_reference/agents/config.py similarity index 62% rename from llama_stack/providers/impls/meta_reference/agents/config.py rename to llama_stack/providers/inline/meta_reference/agents/config.py index 0146cb436..2770ed13c 100644 --- a/llama_stack/providers/impls/meta_reference/agents/config.py +++ b/llama_stack/providers/inline/meta_reference/agents/config.py @@ -4,10 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pydantic import BaseModel +from pydantic import BaseModel, Field from llama_stack.providers.utils.kvstore import KVStoreConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class MetaReferenceAgentsImplConfig(BaseModel): - persistence_store: KVStoreConfig + persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig()) diff --git a/llama_stack/providers/impls/meta_reference/agents/persistence.py b/llama_stack/providers/inline/meta_reference/agents/persistence.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/persistence.py rename to llama_stack/providers/inline/meta_reference/agents/persistence.py diff --git a/llama_stack/providers/adapters/safety/__init__.py b/llama_stack/providers/inline/meta_reference/agents/rag/__init__.py similarity index 100% rename from llama_stack/providers/adapters/safety/__init__.py rename to llama_stack/providers/inline/meta_reference/agents/rag/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/agents/rag/context_retriever.py b/llama_stack/providers/inline/meta_reference/agents/rag/context_retriever.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/rag/context_retriever.py rename to llama_stack/providers/inline/meta_reference/agents/rag/context_retriever.py diff --git a/llama_stack/providers/impls/meta_reference/agents/safety.py b/llama_stack/providers/inline/meta_reference/agents/safety.py similarity index 83% rename from llama_stack/providers/impls/meta_reference/agents/safety.py rename to llama_stack/providers/inline/meta_reference/agents/safety.py index fb5821f6a..915ddd303 100644 --- a/llama_stack/providers/impls/meta_reference/agents/safety.py +++ b/llama_stack/providers/inline/meta_reference/agents/safety.py @@ -32,18 +32,18 @@ class ShieldRunnerMixin: self.output_shields = output_shields async def run_multiple_shields( - self, messages: List[Message], shield_types: List[str] + self, messages: List[Message], identifiers: List[str] ) -> None: responses = await asyncio.gather( *[ self.safety_api.run_shield( - shield_type=shield_type, + identifier=identifier, messages=messages, ) - for shield_type in shield_types + for identifier in identifiers ] ) - for shield_type, response in zip(shield_types, responses): + for identifier, response in zip(identifiers, responses): if not response.violation: continue @@ -52,6 +52,6 @@ class ShieldRunnerMixin: raise SafetyException(violation) elif violation.violation_level == ViolationLevel.WARN: cprint( - f"[Warn]{shield_type} raised a warning", + f"[Warn]{identifier} raised a warning", color="red", ) diff --git a/llama_stack/providers/adapters/telemetry/__init__.py b/llama_stack/providers/inline/meta_reference/agents/tests/__init__.py similarity index 100% rename from llama_stack/providers/adapters/telemetry/__init__.py rename to llama_stack/providers/inline/meta_reference/agents/tests/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/code_execution.py b/llama_stack/providers/inline/meta_reference/agents/tests/code_execution.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/tests/code_execution.py rename to llama_stack/providers/inline/meta_reference/agents/tests/code_execution.py diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py b/llama_stack/providers/inline/meta_reference/agents/tests/test_chat_agent.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py rename to llama_stack/providers/inline/meta_reference/agents/tests/test_chat_agent.py diff --git a/llama_stack/providers/impls/__init__.py b/llama_stack/providers/inline/meta_reference/agents/tools/__init__.py similarity index 100% rename from llama_stack/providers/impls/__init__.py rename to llama_stack/providers/inline/meta_reference/agents/tools/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/agents/tools/base.py b/llama_stack/providers/inline/meta_reference/agents/tools/base.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/tools/base.py rename to llama_stack/providers/inline/meta_reference/agents/tools/base.py diff --git a/llama_stack/providers/impls/meta_reference/agents/tools/builtin.py b/llama_stack/providers/inline/meta_reference/agents/tools/builtin.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/tools/builtin.py rename to llama_stack/providers/inline/meta_reference/agents/tools/builtin.py diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/__init__.py b/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/__init__.py similarity index 100% rename from llama_stack/providers/impls/braintrust/scoring/scoring_fn/__init__.py rename to llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_env_prefix.py b/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_env_prefix.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_env_prefix.py rename to llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_env_prefix.py diff --git a/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_execution.py b/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_execution.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/code_execution.py rename to llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_execution.py diff --git a/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py b/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py rename to llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py diff --git a/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/utils.py b/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/utils.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/utils.py rename to llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/utils.py diff --git a/llama_stack/providers/impls/meta_reference/agents/tools/safety.py b/llama_stack/providers/inline/meta_reference/agents/tools/safety.py similarity index 93% rename from llama_stack/providers/impls/meta_reference/agents/tools/safety.py rename to llama_stack/providers/inline/meta_reference/agents/tools/safety.py index fb95786d1..72530f0e6 100644 --- a/llama_stack/providers/impls/meta_reference/agents/tools/safety.py +++ b/llama_stack/providers/inline/meta_reference/agents/tools/safety.py @@ -9,7 +9,7 @@ from typing import List from llama_stack.apis.inference import Message from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin +from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin from .builtin import BaseTool diff --git a/llama_stack/providers/impls/meta_reference/codeshield/__init__.py b/llama_stack/providers/inline/meta_reference/codeshield/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/codeshield/__init__.py rename to llama_stack/providers/inline/meta_reference/codeshield/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py b/llama_stack/providers/inline/meta_reference/codeshield/code_scanner.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py rename to llama_stack/providers/inline/meta_reference/codeshield/code_scanner.py diff --git a/llama_stack/providers/impls/meta_reference/codeshield/config.py b/llama_stack/providers/inline/meta_reference/codeshield/config.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/codeshield/config.py rename to llama_stack/providers/inline/meta_reference/codeshield/config.py diff --git a/llama_stack/providers/impls/meta_reference/datasetio/__init__.py b/llama_stack/providers/inline/meta_reference/datasetio/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/datasetio/__init__.py rename to llama_stack/providers/inline/meta_reference/datasetio/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/datasetio/config.py b/llama_stack/providers/inline/meta_reference/datasetio/config.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/datasetio/config.py rename to llama_stack/providers/inline/meta_reference/datasetio/config.py diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/inline/meta_reference/datasetio/datasetio.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/datasetio/datasetio.py rename to llama_stack/providers/inline/meta_reference/datasetio/datasetio.py diff --git a/llama_stack/providers/impls/meta_reference/eval/__init__.py b/llama_stack/providers/inline/meta_reference/eval/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/eval/__init__.py rename to llama_stack/providers/inline/meta_reference/eval/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/eval/config.py b/llama_stack/providers/inline/meta_reference/eval/config.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/eval/config.py rename to llama_stack/providers/inline/meta_reference/eval/config.py diff --git a/llama_stack/providers/impls/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/eval/eval.py rename to llama_stack/providers/inline/meta_reference/eval/eval.py diff --git a/llama_stack/providers/impls/meta_reference/inference/__init__.py b/llama_stack/providers/inline/meta_reference/inference/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/__init__.py rename to llama_stack/providers/inline/meta_reference/inference/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/inference/config.py b/llama_stack/providers/inline/meta_reference/inference/config.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/config.py rename to llama_stack/providers/inline/meta_reference/inference/config.py diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/inline/meta_reference/inference/generation.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/generation.py rename to llama_stack/providers/inline/meta_reference/inference/generation.py diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/inline/meta_reference/inference/inference.py similarity index 92% rename from llama_stack/providers/impls/meta_reference/inference/inference.py rename to llama_stack/providers/inline/meta_reference/inference/inference.py index 5588be6c0..b643ac238 100644 --- a/llama_stack/providers/impls/meta_reference/inference/inference.py +++ b/llama_stack/providers/inline/meta_reference/inference/inference.py @@ -14,6 +14,11 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate +from llama_stack.providers.utils.inference.prompt_adapter import ( + convert_image_media_to_url, + request_has_media, +) + from .config import MetaReferenceInferenceConfig from .generation import Llama from .model_parallel import LlamaModelParallelGenerator @@ -87,6 +92,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): logprobs=logprobs, ) self.check_model(request) + request = await request_with_localized_media(request) if request.stream: return self._stream_completion(request) @@ -211,6 +217,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): logprobs=logprobs, ) self.check_model(request) + request = await request_with_localized_media(request) if self.config.create_distributed_process_group: if SEMAPHORE.locked(): @@ -388,3 +395,31 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() + + +async def request_with_localized_media( + request: Union[ChatCompletionRequest, CompletionRequest], +) -> Union[ChatCompletionRequest, CompletionRequest]: + if not request_has_media(request): + return request + + async def _convert_single_content(content): + if isinstance(content, ImageMedia): + url = await convert_image_media_to_url(content, download=True) + return ImageMedia(image=URL(uri=url)) + else: + return content + + async def _convert_content(content): + if isinstance(content, list): + return [await _convert_single_content(c) for c in content] + else: + return await _convert_single_content(content) + + if isinstance(request, ChatCompletionRequest): + for m in request.messages: + m.content = await _convert_content(m.content) + else: + request.content = await _convert_content(request.content) + + return request diff --git a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py b/llama_stack/providers/inline/meta_reference/inference/model_parallel.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/model_parallel.py rename to llama_stack/providers/inline/meta_reference/inference/model_parallel.py diff --git a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py b/llama_stack/providers/inline/meta_reference/inference/parallel_utils.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/parallel_utils.py rename to llama_stack/providers/inline/meta_reference/inference/parallel_utils.py diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/__init__.py b/llama_stack/providers/inline/meta_reference/inference/quantization/__init__.py similarity index 100% rename from llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/__init__.py rename to llama_stack/providers/inline/meta_reference/inference/quantization/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/fp8_impls.py b/llama_stack/providers/inline/meta_reference/inference/quantization/fp8_impls.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/quantization/fp8_impls.py rename to llama_stack/providers/inline/meta_reference/inference/quantization/fp8_impls.py diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/fp8_txest_disabled.py b/llama_stack/providers/inline/meta_reference/inference/quantization/fp8_txest_disabled.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/quantization/fp8_txest_disabled.py rename to llama_stack/providers/inline/meta_reference/inference/quantization/fp8_txest_disabled.py diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/hadamard_utils.py b/llama_stack/providers/inline/meta_reference/inference/quantization/hadamard_utils.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/quantization/hadamard_utils.py rename to llama_stack/providers/inline/meta_reference/inference/quantization/hadamard_utils.py diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py b/llama_stack/providers/inline/meta_reference/inference/quantization/loader.py similarity index 99% rename from llama_stack/providers/impls/meta_reference/inference/quantization/loader.py rename to llama_stack/providers/inline/meta_reference/inference/quantization/loader.py index 9f30354bb..3492ab043 100644 --- a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py +++ b/llama_stack/providers/inline/meta_reference/inference/quantization/loader.py @@ -27,7 +27,7 @@ from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear from llama_stack.apis.inference import QuantizationType -from llama_stack.providers.impls.meta_reference.inference.config import ( +from llama_stack.providers.inline.meta_reference.inference.config import ( MetaReferenceQuantizedInferenceConfig, ) diff --git a/llama_stack/providers/impls/meta_reference/__init__.py b/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/__init__.py rename to llama_stack/providers/inline/meta_reference/inference/quantization/scripts/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/scripts/build_conda.sh b/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/build_conda.sh similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/quantization/scripts/build_conda.sh rename to llama_stack/providers/inline/meta_reference/inference/quantization/scripts/build_conda.sh diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/scripts/quantize_checkpoint.py b/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/quantize_checkpoint.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/quantization/scripts/quantize_checkpoint.py rename to llama_stack/providers/inline/meta_reference/inference/quantization/scripts/quantize_checkpoint.py diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/scripts/run_quantize_checkpoint.sh b/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/run_quantize_checkpoint.sh similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/quantization/scripts/run_quantize_checkpoint.sh rename to llama_stack/providers/inline/meta_reference/inference/quantization/scripts/run_quantize_checkpoint.sh diff --git a/llama_stack/providers/impls/meta_reference/memory/__init__.py b/llama_stack/providers/inline/meta_reference/memory/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/memory/__init__.py rename to llama_stack/providers/inline/meta_reference/memory/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/memory/config.py b/llama_stack/providers/inline/meta_reference/memory/config.py new file mode 100644 index 000000000..41970b05f --- /dev/null +++ b/llama_stack/providers/inline/meta_reference/memory/config.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel + +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) + + +@json_schema_type +class FaissImplConfig(BaseModel): + kvstore: KVStoreConfig = SqliteKVStoreConfig( + db_path=(RUNTIME_BASE_DIR / "faiss_store.db").as_posix() + ) # Uses SQLite config specific to FAISS storage diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/inline/meta_reference/memory/faiss.py similarity index 77% rename from llama_stack/providers/impls/meta_reference/memory/faiss.py rename to llama_stack/providers/inline/meta_reference/memory/faiss.py index 02829f7be..4bd5fd5a7 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/inline/meta_reference/memory/faiss.py @@ -16,6 +16,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.memory.vector_store import ( ALL_MINILM_L6_V2_DIMENSION, @@ -28,6 +29,8 @@ from .config import FaissImplConfig logger = logging.getLogger(__name__) +MEMORY_BANKS_PREFIX = "memory_banks:" + class FaissIndex(EmbeddingIndex): id_by_index: Dict[int, str] @@ -69,10 +72,25 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): def __init__(self, config: FaissImplConfig) -> None: self.config = config self.cache = {} + self.kvstore = None - async def initialize(self) -> None: ... + async def initialize(self) -> None: + self.kvstore = await kvstore_impl(self.config.kvstore) + # Load existing banks from kvstore + start_key = MEMORY_BANKS_PREFIX + end_key = f"{MEMORY_BANKS_PREFIX}\xff" + stored_banks = await self.kvstore.range(start_key, end_key) - async def shutdown(self) -> None: ... + for bank_data in stored_banks: + bank = VectorMemoryBankDef.model_validate_json(bank_data) + index = BankWithIndex( + bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) + ) + self.cache[bank.identifier] = index + + async def shutdown(self) -> None: + # Cleanup if needed + pass async def register_memory_bank( self, @@ -82,6 +100,14 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): memory_bank.type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.type}" + # Store in kvstore + key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}" + await self.kvstore.set( + key=key, + value=memory_bank.json(), + ) + + # Store in cache index = BankWithIndex( bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) ) diff --git a/llama_stack/providers/inline/meta_reference/memory/tests/test_faiss.py b/llama_stack/providers/inline/meta_reference/memory/tests/test_faiss.py new file mode 100644 index 000000000..7b944319f --- /dev/null +++ b/llama_stack/providers/inline/meta_reference/memory/tests/test_faiss.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import tempfile + +import pytest +from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef +from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig + +from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + + +class TestFaissMemoryImpl: + @pytest.fixture + def faiss_impl(self): + # Create a temporary SQLite database file + temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name)) + return FaissMemoryImpl(config) + + @pytest.mark.asyncio + async def test_initialize(self, faiss_impl): + # Test empty initialization + await faiss_impl.initialize() + assert len(faiss_impl.cache) == 0 + + # Test initialization with existing banks + bank = VectorMemoryBankDef( + identifier="test_bank", + type=MemoryBankType.vector.value, + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ) + + # Register a bank and reinitialize to test loading + await faiss_impl.register_memory_bank(bank) + + # Create new instance to test initialization with existing data + new_impl = FaissMemoryImpl(faiss_impl.config) + await new_impl.initialize() + + assert len(new_impl.cache) == 1 + assert "test_bank" in new_impl.cache + + @pytest.mark.asyncio + async def test_register_memory_bank(self, faiss_impl): + bank = VectorMemoryBankDef( + identifier="test_bank", + type=MemoryBankType.vector.value, + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ) + + await faiss_impl.initialize() + await faiss_impl.register_memory_bank(bank) + + assert "test_bank" in faiss_impl.cache + assert faiss_impl.cache["test_bank"].bank == bank + + # Verify persistence + new_impl = FaissMemoryImpl(faiss_impl.config) + await new_impl.initialize() + assert "test_bank" in new_impl.cache + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/llama_stack/providers/impls/meta_reference/safety/__init__.py b/llama_stack/providers/inline/meta_reference/safety/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/safety/__init__.py rename to llama_stack/providers/inline/meta_reference/safety/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/safety/base.py b/llama_stack/providers/inline/meta_reference/safety/base.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/safety/base.py rename to llama_stack/providers/inline/meta_reference/safety/base.py diff --git a/llama_stack/providers/impls/meta_reference/safety/config.py b/llama_stack/providers/inline/meta_reference/safety/config.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/safety/config.py rename to llama_stack/providers/inline/meta_reference/safety/config.py diff --git a/llama_stack/providers/impls/meta_reference/safety/llama_guard.py b/llama_stack/providers/inline/meta_reference/safety/llama_guard.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/safety/llama_guard.py rename to llama_stack/providers/inline/meta_reference/safety/llama_guard.py diff --git a/llama_stack/providers/impls/meta_reference/safety/prompt_guard.py b/llama_stack/providers/inline/meta_reference/safety/prompt_guard.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/safety/prompt_guard.py rename to llama_stack/providers/inline/meta_reference/safety/prompt_guard.py diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/inline/meta_reference/safety/safety.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/safety/safety.py rename to llama_stack/providers/inline/meta_reference/safety/safety.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/__init__.py b/llama_stack/providers/inline/meta_reference/scoring/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/scoring/__init__.py rename to llama_stack/providers/inline/meta_reference/scoring/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/config.py b/llama_stack/providers/inline/meta_reference/scoring/config.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/scoring/config.py rename to llama_stack/providers/inline/meta_reference/scoring/config.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/inline/meta_reference/scoring/scoring.py similarity index 94% rename from llama_stack/providers/impls/meta_reference/scoring/scoring.py rename to llama_stack/providers/inline/meta_reference/scoring/scoring.py index 41b24a512..709b2f0c6 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring.py @@ -13,15 +13,15 @@ from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.inference.inference import Inference from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.equality_scoring_fn import ( EqualityScoringFn, ) -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import ( LlmAsJudgeScoringFn, ) -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import ( SubsetOfScoringFn, ) diff --git a/llama_stack/providers/impls/meta_reference/agents/rag/__init__.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/rag/__init__.py rename to llama_stack/providers/inline/meta_reference/scoring/scoring_fn/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py rename to llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/common.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py rename to llama_stack/providers/inline/meta_reference/scoring/scoring_fn/common.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py similarity index 85% rename from llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py rename to llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py index 556436286..2a0cd0578 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py @@ -4,18 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import ( BaseScoringFn, ) from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import ( aggregate_accuracy, ) -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.equality import ( equality, ) diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/__init__.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/tests/__init__.py rename to llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py rename to llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/equality.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py rename to llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py rename to llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py similarity index 90% rename from llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py rename to llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py index 5a5ce2550..84dd28fd7 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. from llama_stack.apis.inference.inference import Inference -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import ( BaseScoringFn, ) from llama_stack.apis.scoring_functions import * # noqa: F401, F403 @@ -12,10 +12,10 @@ from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 import re -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import ( aggregate_average, ) -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import ( llm_as_judge_8b_correctness, ) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py similarity index 83% rename from llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py rename to llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py index fcef2ead7..f42964c1f 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py @@ -4,17 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import ( BaseScoringFn, ) from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import ( aggregate_accuracy, ) -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import ( +from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.subset_of import ( subset_of, ) diff --git a/llama_stack/providers/impls/meta_reference/telemetry/__init__.py b/llama_stack/providers/inline/meta_reference/telemetry/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/telemetry/__init__.py rename to llama_stack/providers/inline/meta_reference/telemetry/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/telemetry/config.py b/llama_stack/providers/inline/meta_reference/telemetry/config.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/telemetry/config.py rename to llama_stack/providers/inline/meta_reference/telemetry/config.py diff --git a/llama_stack/providers/impls/meta_reference/telemetry/console.py b/llama_stack/providers/inline/meta_reference/telemetry/console.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/telemetry/console.py rename to llama_stack/providers/inline/meta_reference/telemetry/console.py diff --git a/llama_stack/providers/impls/vllm/__init__.py b/llama_stack/providers/inline/vllm/__init__.py similarity index 100% rename from llama_stack/providers/impls/vllm/__init__.py rename to llama_stack/providers/inline/vllm/__init__.py diff --git a/llama_stack/providers/impls/vllm/config.py b/llama_stack/providers/inline/vllm/config.py similarity index 100% rename from llama_stack/providers/impls/vllm/config.py rename to llama_stack/providers/inline/vllm/config.py diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/inline/vllm/vllm.py similarity index 100% rename from llama_stack/providers/impls/vllm/vllm.py rename to llama_stack/providers/inline/vllm/vllm.py diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 8f4d3a03e..774dde858 100644 --- a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -22,8 +22,8 @@ def available_providers() -> List[ProviderSpec]: "scikit-learn", ] + kvstore_dependencies(), - module="llama_stack.providers.impls.meta_reference.agents", - config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceAgentsImplConfig", + module="llama_stack.providers.inline.meta_reference.agents", + config_class="llama_stack.providers.inline.meta_reference.agents.MetaReferenceAgentsImplConfig", api_dependencies=[ Api.inference, Api.safety, @@ -36,8 +36,8 @@ def available_providers() -> List[ProviderSpec]: adapter=AdapterSpec( adapter_type="sample", pip_packages=[], - module="llama_stack.providers.adapters.agents.sample", - config_class="llama_stack.providers.adapters.agents.sample.SampleConfig", + module="llama_stack.providers.remote.agents.sample", + config_class="llama_stack.providers.remote.agents.sample.SampleConfig", ), ), ] diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index 27e80ff57..976bbd448 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -15,8 +15,8 @@ def available_providers() -> List[ProviderSpec]: api=Api.datasetio, provider_type="meta-reference", pip_packages=["pandas"], - module="llama_stack.providers.impls.meta_reference.datasetio", - config_class="llama_stack.providers.impls.meta_reference.datasetio.MetaReferenceDatasetIOConfig", + module="llama_stack.providers.inline.meta_reference.datasetio", + config_class="llama_stack.providers.inline.meta_reference.datasetio.MetaReferenceDatasetIOConfig", api_dependencies=[], ), ] diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py index fc7c923d9..9b9ba6409 100644 --- a/llama_stack/providers/registry/eval.py +++ b/llama_stack/providers/registry/eval.py @@ -15,8 +15,8 @@ def available_providers() -> List[ProviderSpec]: api=Api.eval, provider_type="meta-reference", pip_packages=[], - module="llama_stack.providers.impls.meta_reference.eval", - config_class="llama_stack.providers.impls.meta_reference.eval.MetaReferenceEvalConfig", + module="llama_stack.providers.inline.meta_reference.eval", + config_class="llama_stack.providers.inline.meta_reference.eval.MetaReferenceEvalConfig", api_dependencies=[ Api.datasetio, Api.datasets, diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 88265f1b4..8a3619118 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]: api=Api.inference, provider_type="meta-reference", pip_packages=META_REFERENCE_DEPS, - module="llama_stack.providers.impls.meta_reference.inference", - config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig", + module="llama_stack.providers.inline.meta_reference.inference", + config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceInferenceConfig", ), InlineProviderSpec( api=Api.inference, @@ -40,16 +40,16 @@ def available_providers() -> List[ProviderSpec]: "torchao==0.5.0", ] ), - module="llama_stack.providers.impls.meta_reference.inference", - config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig", + module="llama_stack.providers.inline.meta_reference.inference", + config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceQuantizedInferenceConfig", ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( adapter_type="sample", pip_packages=[], - module="llama_stack.providers.adapters.inference.sample", - config_class="llama_stack.providers.adapters.inference.sample.SampleConfig", + module="llama_stack.providers.remote.inference.sample", + config_class="llama_stack.providers.remote.inference.sample.SampleConfig", ), ), remote_provider_spec( @@ -57,26 +57,26 @@ def available_providers() -> List[ProviderSpec]: adapter=AdapterSpec( adapter_type="ollama", pip_packages=["ollama", "aiohttp"], - config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig", - module="llama_stack.providers.adapters.inference.ollama", + config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig", + module="llama_stack.providers.remote.inference.ollama", + ), + ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="vllm", + pip_packages=["openai"], + module="llama_stack.providers.remote.inference.vllm", + config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig", ), ), - # remote_provider_spec( - # api=Api.inference, - # adapter=AdapterSpec( - # adapter_type="vllm", - # pip_packages=["openai"], - # module="llama_stack.providers.adapters.inference.vllm", - # config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig", - # ), - # ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( adapter_type="tgi", pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.adapters.inference.tgi", - config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig", + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig", ), ), remote_provider_spec( @@ -84,8 +84,8 @@ def available_providers() -> List[ProviderSpec]: adapter=AdapterSpec( adapter_type="hf::serverless", pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.adapters.inference.tgi", - config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig", + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig", ), ), remote_provider_spec( @@ -93,8 +93,8 @@ def available_providers() -> List[ProviderSpec]: adapter=AdapterSpec( adapter_type="hf::endpoint", pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.adapters.inference.tgi", - config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig", + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig", ), ), remote_provider_spec( @@ -104,8 +104,8 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[ "fireworks-ai", ], - module="llama_stack.providers.adapters.inference.fireworks", - config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig", + module="llama_stack.providers.remote.inference.fireworks", + config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", ), ), remote_provider_spec( @@ -115,9 +115,9 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[ "together", ], - module="llama_stack.providers.adapters.inference.together", - config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig", - provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator", + module="llama_stack.providers.remote.inference.together", + config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig", + provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator", ), ), remote_provider_spec( @@ -125,8 +125,8 @@ def available_providers() -> List[ProviderSpec]: adapter=AdapterSpec( adapter_type="bedrock", pip_packages=["boto3"], - module="llama_stack.providers.adapters.inference.bedrock", - config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig", + module="llama_stack.providers.remote.inference.bedrock", + config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig", ), ), remote_provider_spec( @@ -136,8 +136,8 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[ "openai", ], - module="llama_stack.providers.adapters.inference.databricks", - config_class="llama_stack.providers.adapters.inference.databricks.DatabricksImplConfig", + module="llama_stack.providers.remote.inference.databricks", + config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", ), ), InlineProviderSpec( @@ -146,7 +146,7 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[ "vllm", ], - module="llama_stack.providers.impls.vllm", - config_class="llama_stack.providers.impls.vllm.VLLMConfig", + module="llama_stack.providers.inline.vllm", + config_class="llama_stack.providers.inline.vllm.VLLMConfig", ), ] diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index a0fbf1636..c2740017a 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -36,15 +36,15 @@ def available_providers() -> List[ProviderSpec]: api=Api.memory, provider_type="meta-reference", pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], - module="llama_stack.providers.impls.meta_reference.memory", - config_class="llama_stack.providers.impls.meta_reference.memory.FaissImplConfig", + module="llama_stack.providers.inline.meta_reference.memory", + config_class="llama_stack.providers.inline.meta_reference.memory.FaissImplConfig", ), remote_provider_spec( Api.memory, AdapterSpec( adapter_type="chromadb", pip_packages=EMBEDDING_DEPS + ["chromadb-client"], - module="llama_stack.providers.adapters.memory.chroma", + module="llama_stack.providers.remote.memory.chroma", ), ), remote_provider_spec( @@ -52,8 +52,8 @@ def available_providers() -> List[ProviderSpec]: AdapterSpec( adapter_type="pgvector", pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"], - module="llama_stack.providers.adapters.memory.pgvector", - config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig", + module="llama_stack.providers.remote.memory.pgvector", + config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig", ), ), remote_provider_spec( @@ -61,9 +61,9 @@ def available_providers() -> List[ProviderSpec]: AdapterSpec( adapter_type="weaviate", pip_packages=EMBEDDING_DEPS + ["weaviate-client"], - module="llama_stack.providers.adapters.memory.weaviate", - config_class="llama_stack.providers.adapters.memory.weaviate.WeaviateConfig", - provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData", + module="llama_stack.providers.remote.memory.weaviate", + config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig", + provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData", ), ), remote_provider_spec( @@ -71,8 +71,8 @@ def available_providers() -> List[ProviderSpec]: adapter=AdapterSpec( adapter_type="sample", pip_packages=[], - module="llama_stack.providers.adapters.memory.sample", - config_class="llama_stack.providers.adapters.memory.sample.SampleConfig", + module="llama_stack.providers.remote.memory.sample", + config_class="llama_stack.providers.remote.memory.sample.SampleConfig", ), ), remote_provider_spec( @@ -80,8 +80,8 @@ def available_providers() -> List[ProviderSpec]: AdapterSpec( adapter_type="qdrant", pip_packages=EMBEDDING_DEPS + ["qdrant-client"], - module="llama_stack.providers.adapters.memory.qdrant", - config_class="llama_stack.providers.adapters.memory.qdrant.QdrantConfig", + module="llama_stack.providers.remote.memory.qdrant", + config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig", ), ), ] diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 3fa62479a..9279d8df9 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -24,8 +24,8 @@ def available_providers() -> List[ProviderSpec]: "transformers", "torch --index-url https://download.pytorch.org/whl/cpu", ], - module="llama_stack.providers.impls.meta_reference.safety", - config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig", + module="llama_stack.providers.inline.meta_reference.safety", + config_class="llama_stack.providers.inline.meta_reference.safety.SafetyConfig", api_dependencies=[ Api.inference, ], @@ -35,8 +35,8 @@ def available_providers() -> List[ProviderSpec]: adapter=AdapterSpec( adapter_type="sample", pip_packages=[], - module="llama_stack.providers.adapters.safety.sample", - config_class="llama_stack.providers.adapters.safety.sample.SampleConfig", + module="llama_stack.providers.remote.safety.sample", + config_class="llama_stack.providers.remote.safety.sample.SampleConfig", ), ), remote_provider_spec( @@ -44,8 +44,8 @@ def available_providers() -> List[ProviderSpec]: adapter=AdapterSpec( adapter_type="bedrock", pip_packages=["boto3"], - module="llama_stack.providers.adapters.safety.bedrock", - config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig", + module="llama_stack.providers.remote.safety.bedrock", + config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", ), ), remote_provider_spec( @@ -55,9 +55,9 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[ "together", ], - module="llama_stack.providers.adapters.safety.together", - config_class="llama_stack.providers.adapters.safety.together.TogetherSafetyConfig", - provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator", + module="llama_stack.providers.remote.safety.together", + config_class="llama_stack.providers.remote.safety.together.TogetherSafetyConfig", + provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator", ), ), InlineProviderSpec( @@ -66,8 +66,8 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[ "codeshield", ], - module="llama_stack.providers.impls.meta_reference.codeshield", - config_class="llama_stack.providers.impls.meta_reference.codeshield.CodeShieldConfig", + module="llama_stack.providers.inline.meta_reference.codeshield", + config_class="llama_stack.providers.inline.meta_reference.codeshield.CodeShieldConfig", api_dependencies=[], ), ] diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py index 81cb47764..2586083f6 100644 --- a/llama_stack/providers/registry/scoring.py +++ b/llama_stack/providers/registry/scoring.py @@ -15,8 +15,8 @@ def available_providers() -> List[ProviderSpec]: api=Api.scoring, provider_type="meta-reference", pip_packages=[], - module="llama_stack.providers.impls.meta_reference.scoring", - config_class="llama_stack.providers.impls.meta_reference.scoring.MetaReferenceScoringConfig", + module="llama_stack.providers.inline.meta_reference.scoring", + config_class="llama_stack.providers.inline.meta_reference.scoring.MetaReferenceScoringConfig", api_dependencies=[ Api.datasetio, Api.datasets, @@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]: api=Api.scoring, provider_type="braintrust", pip_packages=["autoevals", "openai"], - module="llama_stack.providers.impls.braintrust.scoring", - config_class="llama_stack.providers.impls.braintrust.scoring.BraintrustScoringConfig", + module="llama_stack.providers.inline.braintrust.scoring", + config_class="llama_stack.providers.inline.braintrust.scoring.BraintrustScoringConfig", api_dependencies=[ Api.datasetio, Api.datasets, diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index 39bcb75d8..050d890aa 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -15,16 +15,16 @@ def available_providers() -> List[ProviderSpec]: api=Api.telemetry, provider_type="meta-reference", pip_packages=[], - module="llama_stack.providers.impls.meta_reference.telemetry", - config_class="llama_stack.providers.impls.meta_reference.telemetry.ConsoleConfig", + module="llama_stack.providers.inline.meta_reference.telemetry", + config_class="llama_stack.providers.inline.meta_reference.telemetry.ConsoleConfig", ), remote_provider_spec( api=Api.telemetry, adapter=AdapterSpec( adapter_type="sample", pip_packages=[], - module="llama_stack.providers.adapters.telemetry.sample", - config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig", + module="llama_stack.providers.remote.telemetry.sample", + config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig", ), ), remote_provider_spec( @@ -37,8 +37,8 @@ def available_providers() -> List[ProviderSpec]: "opentelemetry-exporter-jaeger", "opentelemetry-semantic-conventions", ], - module="llama_stack.providers.adapters.telemetry.opentelemetry", - config_class="llama_stack.providers.adapters.telemetry.opentelemetry.OpenTelemetryConfig", + module="llama_stack.providers.remote.telemetry.opentelemetry", + config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig", ), ), ] diff --git a/llama_stack/providers/impls/meta_reference/agents/tools/__init__.py b/llama_stack/providers/remote/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/tools/__init__.py rename to llama_stack/providers/remote/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/__init__.py b/llama_stack/providers/remote/agents/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/agents/tools/ipython_tool/__init__.py rename to llama_stack/providers/remote/agents/__init__.py diff --git a/llama_stack/providers/adapters/agents/sample/__init__.py b/llama_stack/providers/remote/agents/sample/__init__.py similarity index 100% rename from llama_stack/providers/adapters/agents/sample/__init__.py rename to llama_stack/providers/remote/agents/sample/__init__.py diff --git a/llama_stack/providers/adapters/agents/sample/config.py b/llama_stack/providers/remote/agents/sample/config.py similarity index 100% rename from llama_stack/providers/adapters/agents/sample/config.py rename to llama_stack/providers/remote/agents/sample/config.py diff --git a/llama_stack/providers/adapters/agents/sample/sample.py b/llama_stack/providers/remote/agents/sample/sample.py similarity index 100% rename from llama_stack/providers/adapters/agents/sample/sample.py rename to llama_stack/providers/remote/agents/sample/sample.py diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/__init__.py b/llama_stack/providers/remote/inference/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/quantization/__init__.py rename to llama_stack/providers/remote/inference/__init__.py diff --git a/llama_stack/providers/adapters/inference/bedrock/__init__.py b/llama_stack/providers/remote/inference/bedrock/__init__.py similarity index 100% rename from llama_stack/providers/adapters/inference/bedrock/__init__.py rename to llama_stack/providers/remote/inference/bedrock/__init__.py diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py similarity index 92% rename from llama_stack/providers/adapters/inference/bedrock/bedrock.py rename to llama_stack/providers/remote/inference/bedrock/bedrock.py index caf886c0b..f569e0093 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -6,9 +6,7 @@ from typing import * # noqa: F403 -import boto3 from botocore.client import BaseClient -from botocore.config import Config from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer @@ -16,7 +14,9 @@ from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig + +from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig +from llama_stack.providers.utils.bedrock.client import create_bedrock_client BEDROCK_SUPPORTED_MODELS = { @@ -34,7 +34,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): ) self._config = config - self._client = _create_bedrock_client(config) + self._client = create_bedrock_client(config) self.formatter = ChatFormat(Tokenizer.get_instance()) @property @@ -437,43 +437,3 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() - - -def _create_bedrock_client(config: BedrockConfig) -> BaseClient: - retries_config = { - k: v - for k, v in dict( - total_max_attempts=config.total_max_attempts, - mode=config.retry_mode, - ).items() - if v is not None - } - - config_args = { - k: v - for k, v in dict( - region_name=config.region_name, - retries=retries_config if retries_config else None, - connect_timeout=config.connect_timeout, - read_timeout=config.read_timeout, - ).items() - if v is not None - } - - boto3_config = Config(**config_args) - - session_args = { - k: v - for k, v in dict( - aws_access_key_id=config.aws_access_key_id, - aws_secret_access_key=config.aws_secret_access_key, - aws_session_token=config.aws_session_token, - region_name=config.region_name, - profile_name=config.profile_name, - ).items() - if v is not None - } - - boto3_session = boto3.session.Session(**session_args) - - return boto3_session.client("bedrock-runtime", config=boto3_config) diff --git a/llama_stack/providers/remote/inference/bedrock/config.py b/llama_stack/providers/remote/inference/bedrock/config.py new file mode 100644 index 000000000..8e194700c --- /dev/null +++ b/llama_stack/providers/remote/inference/bedrock/config.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_models.schema_utils import json_schema_type + +from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig + + +@json_schema_type +class BedrockConfig(BedrockBaseConfig): + pass diff --git a/llama_stack/providers/adapters/inference/databricks/__init__.py b/llama_stack/providers/remote/inference/databricks/__init__.py similarity index 100% rename from llama_stack/providers/adapters/inference/databricks/__init__.py rename to llama_stack/providers/remote/inference/databricks/__init__.py diff --git a/llama_stack/providers/adapters/inference/databricks/config.py b/llama_stack/providers/remote/inference/databricks/config.py similarity index 100% rename from llama_stack/providers/adapters/inference/databricks/config.py rename to llama_stack/providers/remote/inference/databricks/config.py diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py similarity index 100% rename from llama_stack/providers/adapters/inference/databricks/databricks.py rename to llama_stack/providers/remote/inference/databricks/databricks.py diff --git a/llama_stack/providers/adapters/inference/fireworks/__init__.py b/llama_stack/providers/remote/inference/fireworks/__init__.py similarity index 100% rename from llama_stack/providers/adapters/inference/fireworks/__init__.py rename to llama_stack/providers/remote/inference/fireworks/__init__.py diff --git a/llama_stack/providers/adapters/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py similarity index 100% rename from llama_stack/providers/adapters/inference/fireworks/config.py rename to llama_stack/providers/remote/inference/fireworks/config.py diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py similarity index 78% rename from llama_stack/providers/adapters/inference/fireworks/fireworks.py rename to llama_stack/providers/remote/inference/fireworks/fireworks.py index 5b5a03196..0070756d8 100644 --- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -26,6 +26,8 @@ from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, completion_request_to_prompt, + convert_message_to_dict, + request_has_media, ) from .config import FireworksImplConfig @@ -82,14 +84,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): async def _nonstream_completion( self, request: CompletionRequest, client: Fireworks ) -> CompletionResponse: - params = self._get_params(request) + params = await self._get_params(request) r = await client.completion.acreate(**params) return process_completion_response(r, self.formatter) async def _stream_completion( self, request: CompletionRequest, client: Fireworks ) -> AsyncGenerator: - params = self._get_params(request) + params = await self._get_params(request) stream = client.completion.acreate(**params) async for chunk in process_completion_stream_response(stream, self.formatter): @@ -128,33 +130,55 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): async def _nonstream_chat_completion( self, request: ChatCompletionRequest, client: Fireworks ) -> ChatCompletionResponse: - params = self._get_params(request) - r = await client.completion.acreate(**params) + params = await self._get_params(request) + if "messages" in params: + r = await client.chat.completions.acreate(**params) + else: + r = await client.completion.acreate(**params) return process_chat_completion_response(r, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest, client: Fireworks ) -> AsyncGenerator: - params = self._get_params(request) + params = await self._get_params(request) + + if "messages" in params: + stream = client.chat.completions.acreate(**params) + else: + stream = client.completion.acreate(**params) - stream = client.completion.acreate(**params) async for chunk in process_chat_completion_stream_response( stream, self.formatter ): yield chunk - def _get_params(self, request) -> dict: - prompt = "" - if type(request) == ChatCompletionRequest: - prompt = chat_completion_request_to_prompt(request, self.formatter) - elif type(request) == CompletionRequest: - prompt = completion_request_to_prompt(request, self.formatter) + async def _get_params( + self, request: Union[ChatCompletionRequest, CompletionRequest] + ) -> dict: + input_dict = {} + media_present = request_has_media(request) + + if isinstance(request, ChatCompletionRequest): + if media_present: + input_dict["messages"] = [ + await convert_message_to_dict(m) for m in request.messages + ] + else: + input_dict["prompt"] = chat_completion_request_to_prompt( + request, self.formatter + ) + elif isinstance(request, CompletionRequest): + assert ( + not media_present + ), "Fireworks does not support media for Completion requests" + input_dict["prompt"] = completion_request_to_prompt(request, self.formatter) else: raise ValueError(f"Unknown request type {type(request)}") # Fireworks always prepends with BOS - if prompt.startswith("<|begin_of_text|>"): - prompt = prompt[len("<|begin_of_text|>") :] + if "prompt" in input_dict: + if input_dict["prompt"].startswith("<|begin_of_text|>"): + input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :] options = get_sampling_options(request.sampling_params) options.setdefault("max_tokens", 512) @@ -172,9 +196,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): } else: raise ValueError(f"Unknown response format {fmt.type}") + return { "model": self.map_to_provider_model(request.model), - "prompt": prompt, + **input_dict, "stream": request.stream, **options, } diff --git a/llama_stack/providers/adapters/inference/ollama/__init__.py b/llama_stack/providers/remote/inference/ollama/__init__.py similarity index 100% rename from llama_stack/providers/adapters/inference/ollama/__init__.py rename to llama_stack/providers/remote/inference/ollama/__init__.py diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py similarity index 68% rename from llama_stack/providers/adapters/inference/ollama/ollama.py rename to llama_stack/providers/remote/inference/ollama/ollama.py index 916241a7c..3530e1234 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -29,6 +29,8 @@ from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, completion_request_to_prompt, + convert_image_media_to_url, + request_has_media, ) OLLAMA_SUPPORTED_MODELS = { @@ -38,6 +40,7 @@ OLLAMA_SUPPORTED_MODELS = { "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16", "Llama-Guard-3-8B": "llama-guard3:8b", "Llama-Guard-3-1B": "llama-guard3:1b", + "Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16", } @@ -109,22 +112,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): else: return await self._nonstream_completion(request) - def _get_params_for_completion(self, request: CompletionRequest) -> dict: - sampling_options = get_sampling_options(request.sampling_params) - # This is needed since the Ollama API expects num_predict to be set - # for early truncation instead of max_tokens. - if sampling_options["max_tokens"] is not None: - sampling_options["num_predict"] = sampling_options["max_tokens"] - return { - "model": OLLAMA_SUPPORTED_MODELS[request.model], - "prompt": completion_request_to_prompt(request, self.formatter), - "options": sampling_options, - "raw": True, - "stream": request.stream, - } - async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: - params = self._get_params_for_completion(request) + params = await self._get_params(request) async def _generate_and_convert_to_openai_compat(): s = await self.client.generate(**params) @@ -142,7 +131,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): yield chunk async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator: - params = self._get_params_for_completion(request) + params = await self._get_params(request) r = await self.client.generate(**params) assert isinstance(r, dict) @@ -183,26 +172,66 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): else: return await self._nonstream_chat_completion(request) - def _get_params(self, request: ChatCompletionRequest) -> dict: + async def _get_params( + self, request: Union[ChatCompletionRequest, CompletionRequest] + ) -> dict: + sampling_options = get_sampling_options(request.sampling_params) + # This is needed since the Ollama API expects num_predict to be set + # for early truncation instead of max_tokens. + if sampling_options.get("max_tokens") is not None: + sampling_options["num_predict"] = sampling_options["max_tokens"] + + input_dict = {} + media_present = request_has_media(request) + if isinstance(request, ChatCompletionRequest): + if media_present: + contents = [ + await convert_message_to_dict_for_ollama(m) + for m in request.messages + ] + # flatten the list of lists + input_dict["messages"] = [ + item for sublist in contents for item in sublist + ] + else: + input_dict["raw"] = True + input_dict["prompt"] = chat_completion_request_to_prompt( + request, self.formatter + ) + else: + assert ( + not media_present + ), "Ollama does not support media for Completion requests" + input_dict["prompt"] = completion_request_to_prompt(request, self.formatter) + input_dict["raw"] = True + return { "model": OLLAMA_SUPPORTED_MODELS[request.model], - "prompt": chat_completion_request_to_prompt(request, self.formatter), - "options": get_sampling_options(request.sampling_params), - "raw": True, + **input_dict, + "options": sampling_options, "stream": request.stream, } async def _nonstream_chat_completion( self, request: ChatCompletionRequest ) -> ChatCompletionResponse: - params = self._get_params(request) - r = await self.client.generate(**params) + params = await self._get_params(request) + if "messages" in params: + r = await self.client.chat(**params) + else: + r = await self.client.generate(**params) assert isinstance(r, dict) - choice = OpenAICompatCompletionChoice( - finish_reason=r["done_reason"] if r["done"] else None, - text=r["response"], - ) + if "message" in r: + choice = OpenAICompatCompletionChoice( + finish_reason=r["done_reason"] if r["done"] else None, + text=r["message"]["content"], + ) + else: + choice = OpenAICompatCompletionChoice( + finish_reason=r["done_reason"] if r["done"] else None, + text=r["response"], + ) response = OpenAICompatCompletionResponse( choices=[choice], ) @@ -211,15 +240,24 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): async def _stream_chat_completion( self, request: ChatCompletionRequest ) -> AsyncGenerator: - params = self._get_params(request) + params = await self._get_params(request) async def _generate_and_convert_to_openai_compat(): - s = await self.client.generate(**params) + if "messages" in params: + s = await self.client.chat(**params) + else: + s = await self.client.generate(**params) async for chunk in s: - choice = OpenAICompatCompletionChoice( - finish_reason=chunk["done_reason"] if chunk["done"] else None, - text=chunk["response"], - ) + if "message" in chunk: + choice = OpenAICompatCompletionChoice( + finish_reason=chunk["done_reason"] if chunk["done"] else None, + text=chunk["message"]["content"], + ) + else: + choice = OpenAICompatCompletionChoice( + finish_reason=chunk["done_reason"] if chunk["done"] else None, + text=chunk["response"], + ) yield OpenAICompatCompletionResponse( choices=[choice], ) @@ -236,3 +274,26 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() + + +async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]: + async def _convert_content(content) -> dict: + if isinstance(content, ImageMedia): + return { + "role": message.role, + "images": [ + await convert_image_media_to_url( + content, download=True, include_format=False + ) + ], + } + else: + return { + "role": message.role, + "content": content, + } + + if isinstance(message.content, list): + return [await _convert_content(c) for c in message.content] + else: + return [await _convert_content(message.content)] diff --git a/llama_stack/providers/adapters/inference/sample/__init__.py b/llama_stack/providers/remote/inference/sample/__init__.py similarity index 100% rename from llama_stack/providers/adapters/inference/sample/__init__.py rename to llama_stack/providers/remote/inference/sample/__init__.py diff --git a/llama_stack/providers/adapters/inference/sample/config.py b/llama_stack/providers/remote/inference/sample/config.py similarity index 100% rename from llama_stack/providers/adapters/inference/sample/config.py rename to llama_stack/providers/remote/inference/sample/config.py diff --git a/llama_stack/providers/adapters/inference/sample/sample.py b/llama_stack/providers/remote/inference/sample/sample.py similarity index 100% rename from llama_stack/providers/adapters/inference/sample/sample.py rename to llama_stack/providers/remote/inference/sample/sample.py diff --git a/llama_stack/providers/adapters/inference/tgi/__init__.py b/llama_stack/providers/remote/inference/tgi/__init__.py similarity index 100% rename from llama_stack/providers/adapters/inference/tgi/__init__.py rename to llama_stack/providers/remote/inference/tgi/__init__.py diff --git a/llama_stack/providers/adapters/inference/tgi/config.py b/llama_stack/providers/remote/inference/tgi/config.py similarity index 89% rename from llama_stack/providers/adapters/inference/tgi/config.py rename to llama_stack/providers/remote/inference/tgi/config.py index 6ce2b9dc6..863f81bf7 100644 --- a/llama_stack/providers/adapters/inference/tgi/config.py +++ b/llama_stack/providers/remote/inference/tgi/config.py @@ -12,9 +12,14 @@ from pydantic import BaseModel, Field @json_schema_type class TGIImplConfig(BaseModel): - url: str = Field( - description="The URL for the TGI endpoint (e.g. 'http://localhost:8080')", - ) + host: str = "localhost" + port: int = 8080 + protocol: str = "http" + + @property + def url(self) -> str: + return f"{self.protocol}://{self.host}:{self.port}" + api_token: Optional[str] = Field( default=None, description="A bearer token if your TGI endpoint is protected.", diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py similarity index 100% rename from llama_stack/providers/adapters/inference/tgi/tgi.py rename to llama_stack/providers/remote/inference/tgi/tgi.py diff --git a/llama_stack/providers/adapters/inference/together/__init__.py b/llama_stack/providers/remote/inference/together/__init__.py similarity index 100% rename from llama_stack/providers/adapters/inference/together/__init__.py rename to llama_stack/providers/remote/inference/together/__init__.py diff --git a/llama_stack/providers/adapters/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py similarity index 100% rename from llama_stack/providers/adapters/inference/together/config.py rename to llama_stack/providers/remote/inference/together/config.py diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py similarity index 82% rename from llama_stack/providers/adapters/inference/together/together.py rename to llama_stack/providers/remote/inference/together/together.py index 5decea482..28a566415 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -26,6 +26,8 @@ from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, completion_request_to_prompt, + convert_message_to_dict, + request_has_media, ) from .config import TogetherImplConfig @@ -97,12 +99,12 @@ class TogetherInferenceAdapter( async def _nonstream_completion( self, request: CompletionRequest ) -> ChatCompletionResponse: - params = self._get_params_for_completion(request) + params = await self._get_params(request) r = self._get_client().completions.create(**params) return process_completion_response(r, self.formatter) async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: - params = self._get_params_for_completion(request) + params = await self._get_params(request) # if we shift to TogetherAsyncClient, we won't need this wrapper async def _to_async_generator(): @@ -131,14 +133,6 @@ class TogetherInferenceAdapter( return options - def _get_params_for_completion(self, request: CompletionRequest) -> dict: - return { - "model": self.map_to_provider_model(request.model), - "prompt": completion_request_to_prompt(request, self.formatter), - "stream": request.stream, - **self._build_options(request.sampling_params, request.response_format), - } - async def chat_completion( self, model: str, @@ -171,18 +165,24 @@ class TogetherInferenceAdapter( async def _nonstream_chat_completion( self, request: ChatCompletionRequest ) -> ChatCompletionResponse: - params = self._get_params(request) - r = self._get_client().completions.create(**params) + params = await self._get_params(request) + if "messages" in params: + r = self._get_client().chat.completions.create(**params) + else: + r = self._get_client().completions.create(**params) return process_chat_completion_response(r, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest ) -> AsyncGenerator: - params = self._get_params(request) + params = await self._get_params(request) # if we shift to TogetherAsyncClient, we won't need this wrapper async def _to_async_generator(): - s = self._get_client().completions.create(**params) + if "messages" in params: + s = self._get_client().chat.completions.create(**params) + else: + s = self._get_client().completions.create(**params) for chunk in s: yield chunk @@ -192,10 +192,29 @@ class TogetherInferenceAdapter( ): yield chunk - def _get_params(self, request: ChatCompletionRequest) -> dict: + async def _get_params( + self, request: Union[ChatCompletionRequest, CompletionRequest] + ) -> dict: + input_dict = {} + media_present = request_has_media(request) + if isinstance(request, ChatCompletionRequest): + if media_present: + input_dict["messages"] = [ + await convert_message_to_dict(m) for m in request.messages + ] + else: + input_dict["prompt"] = chat_completion_request_to_prompt( + request, self.formatter + ) + else: + assert ( + not media_present + ), "Together does not support media for Completion requests" + input_dict["prompt"] = completion_request_to_prompt(request, self.formatter) + return { "model": self.map_to_provider_model(request.model), - "prompt": chat_completion_request_to_prompt(request, self.formatter), + **input_dict, "stream": request.stream, **self._build_options(request.sampling_params, request.response_format), } diff --git a/llama_stack/providers/adapters/inference/vllm/__init__.py b/llama_stack/providers/remote/inference/vllm/__init__.py similarity index 50% rename from llama_stack/providers/adapters/inference/vllm/__init__.py rename to llama_stack/providers/remote/inference/vllm/__init__.py index f4588a307..78222d7d9 100644 --- a/llama_stack/providers/adapters/inference/vllm/__init__.py +++ b/llama_stack/providers/remote/inference/vllm/__init__.py @@ -4,12 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .config import VLLMImplConfig -from .vllm import VLLMInferenceAdapter +from .config import VLLMInferenceAdapterConfig -async def get_adapter_impl(config: VLLMImplConfig, _deps): - assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}" +async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps): + from .vllm import VLLMInferenceAdapter + + assert isinstance( + config, VLLMInferenceAdapterConfig + ), f"Unexpected config type: {type(config)}" impl = VLLMInferenceAdapter(config) await impl.initialize() return impl diff --git a/llama_stack/providers/adapters/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py similarity index 74% rename from llama_stack/providers/adapters/inference/vllm/config.py rename to llama_stack/providers/remote/inference/vllm/config.py index 65815922c..50a174589 100644 --- a/llama_stack/providers/adapters/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -11,12 +11,16 @@ from pydantic import BaseModel, Field @json_schema_type -class VLLMImplConfig(BaseModel): +class VLLMInferenceAdapterConfig(BaseModel): url: Optional[str] = Field( default=None, description="The URL for the vLLM model serving endpoint", ) + max_tokens: int = Field( + default=4096, + description="Maximum number of tokens to generate.", + ) api_token: Optional[str] = Field( - default=None, + default="fake", description="The API token", ) diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py similarity index 67% rename from llama_stack/providers/adapters/inference/vllm/vllm.py rename to llama_stack/providers/remote/inference/vllm/vllm.py index aad2fdc1f..0259c7061 100644 --- a/llama_stack/providers/adapters/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer +from llama_models.sku_list import all_registered_models, resolve_model from openai import OpenAI @@ -23,42 +24,19 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, ) -from .config import VLLMImplConfig - -VLLM_SUPPORTED_MODELS = { - "Llama3.1-8B": "meta-llama/Llama-3.1-8B", - "Llama3.1-70B": "meta-llama/Llama-3.1-70B", - "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B", - "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8", - "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B", - "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct", - "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct", - "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct", - "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8", - "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct", - "Llama3.2-1B": "meta-llama/Llama-3.2-1B", - "Llama3.2-3B": "meta-llama/Llama-3.2-3B", - "Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision", - "Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision", - "Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct", - "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct", - "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct", - "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct", - "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision", - "Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4", - "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B", - "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B", - "Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8", - "Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M", - "Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B", -} +from .config import VLLMInferenceAdapterConfig class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): - def __init__(self, config: VLLMImplConfig) -> None: + def __init__(self, config: VLLMInferenceAdapterConfig) -> None: self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) self.client = None + self.huggingface_repo_to_llama_model_id = { + model.huggingface_repo: model.descriptor() + for model in all_registered_models() + if model.huggingface_repo + } async def initialize(self) -> None: self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) @@ -70,10 +48,21 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): pass async def list_models(self) -> List[ModelDef]: - return [ - ModelDef(identifier=model.id, llama_model=model.id) - for model in self.client.models.list() - ] + models = [] + for model in self.client.models.list(): + repo = model.id + if repo not in self.huggingface_repo_to_llama_model_id: + print(f"Unknown model served by vllm: {repo}") + continue + + identifier = self.huggingface_repo_to_llama_model_id[repo] + models.append( + ModelDef( + identifier=identifier, + llama_model=identifier, + ) + ) + return models async def completion( self, @@ -118,7 +107,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): ) -> ChatCompletionResponse: params = self._get_params(request) r = client.completions.create(**params) - return process_chat_completion_response(request, r, self.formatter) + return process_chat_completion_response(r, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest, client: OpenAI @@ -139,11 +128,19 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): yield chunk def _get_params(self, request: ChatCompletionRequest) -> dict: + options = get_sampling_options(request.sampling_params) + if "max_tokens" not in options: + options["max_tokens"] = self.config.max_tokens + + model = resolve_model(request.model) + if model is None: + raise ValueError(f"Unknown model: {request.model}") + return { - "model": VLLM_SUPPORTED_MODELS[request.model], + "model": model.huggingface_repo, "prompt": chat_completion_request_to_prompt(request, self.formatter), "stream": request.stream, - **get_sampling_options(request.sampling_params), + **options, } async def embeddings( diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/scripts/__init__.py b/llama_stack/providers/remote/memory/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/inference/quantization/scripts/__init__.py rename to llama_stack/providers/remote/memory/__init__.py diff --git a/llama_stack/providers/adapters/memory/chroma/__init__.py b/llama_stack/providers/remote/memory/chroma/__init__.py similarity index 100% rename from llama_stack/providers/adapters/memory/chroma/__init__.py rename to llama_stack/providers/remote/memory/chroma/__init__.py diff --git a/llama_stack/providers/adapters/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py similarity index 100% rename from llama_stack/providers/adapters/memory/chroma/chroma.py rename to llama_stack/providers/remote/memory/chroma/chroma.py diff --git a/llama_stack/providers/adapters/memory/pgvector/__init__.py b/llama_stack/providers/remote/memory/pgvector/__init__.py similarity index 100% rename from llama_stack/providers/adapters/memory/pgvector/__init__.py rename to llama_stack/providers/remote/memory/pgvector/__init__.py diff --git a/llama_stack/providers/adapters/memory/pgvector/config.py b/llama_stack/providers/remote/memory/pgvector/config.py similarity index 75% rename from llama_stack/providers/adapters/memory/pgvector/config.py rename to llama_stack/providers/remote/memory/pgvector/config.py index 87b2f4a3b..41983e7b2 100644 --- a/llama_stack/providers/adapters/memory/pgvector/config.py +++ b/llama_stack/providers/remote/memory/pgvector/config.py @@ -12,6 +12,6 @@ from pydantic import BaseModel, Field class PGVectorConfig(BaseModel): host: str = Field(default="localhost") port: int = Field(default=5432) - db: str - user: str - password: str + db: str = Field(default="postgres") + user: str = Field(default="postgres") + password: str = Field(default="mysecretpassword") diff --git a/llama_stack/providers/adapters/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py similarity index 100% rename from llama_stack/providers/adapters/memory/pgvector/pgvector.py rename to llama_stack/providers/remote/memory/pgvector/pgvector.py diff --git a/llama_stack/providers/adapters/memory/qdrant/__init__.py b/llama_stack/providers/remote/memory/qdrant/__init__.py similarity index 100% rename from llama_stack/providers/adapters/memory/qdrant/__init__.py rename to llama_stack/providers/remote/memory/qdrant/__init__.py diff --git a/llama_stack/providers/adapters/memory/qdrant/config.py b/llama_stack/providers/remote/memory/qdrant/config.py similarity index 100% rename from llama_stack/providers/adapters/memory/qdrant/config.py rename to llama_stack/providers/remote/memory/qdrant/config.py diff --git a/llama_stack/providers/adapters/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py similarity index 98% rename from llama_stack/providers/adapters/memory/qdrant/qdrant.py rename to llama_stack/providers/remote/memory/qdrant/qdrant.py index 45a8024ac..0f0df3dca 100644 --- a/llama_stack/providers/adapters/memory/qdrant/qdrant.py +++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py @@ -16,7 +16,7 @@ from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.providers.adapters.memory.qdrant.config import QdrantConfig +from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, EmbeddingIndex, diff --git a/llama_stack/providers/adapters/memory/sample/__init__.py b/llama_stack/providers/remote/memory/sample/__init__.py similarity index 100% rename from llama_stack/providers/adapters/memory/sample/__init__.py rename to llama_stack/providers/remote/memory/sample/__init__.py diff --git a/llama_stack/providers/adapters/memory/sample/config.py b/llama_stack/providers/remote/memory/sample/config.py similarity index 100% rename from llama_stack/providers/adapters/memory/sample/config.py rename to llama_stack/providers/remote/memory/sample/config.py diff --git a/llama_stack/providers/adapters/memory/sample/sample.py b/llama_stack/providers/remote/memory/sample/sample.py similarity index 100% rename from llama_stack/providers/adapters/memory/sample/sample.py rename to llama_stack/providers/remote/memory/sample/sample.py diff --git a/llama_stack/providers/adapters/memory/weaviate/__init__.py b/llama_stack/providers/remote/memory/weaviate/__init__.py similarity index 100% rename from llama_stack/providers/adapters/memory/weaviate/__init__.py rename to llama_stack/providers/remote/memory/weaviate/__init__.py diff --git a/llama_stack/providers/adapters/memory/weaviate/config.py b/llama_stack/providers/remote/memory/weaviate/config.py similarity index 100% rename from llama_stack/providers/adapters/memory/weaviate/config.py rename to llama_stack/providers/remote/memory/weaviate/config.py diff --git a/llama_stack/providers/adapters/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py similarity index 100% rename from llama_stack/providers/adapters/memory/weaviate/weaviate.py rename to llama_stack/providers/remote/memory/weaviate/weaviate.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/__init__.py b/llama_stack/providers/remote/safety/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/scoring/scoring_fn/__init__.py rename to llama_stack/providers/remote/safety/__init__.py diff --git a/llama_stack/providers/adapters/safety/bedrock/__init__.py b/llama_stack/providers/remote/safety/bedrock/__init__.py similarity index 100% rename from llama_stack/providers/adapters/safety/bedrock/__init__.py rename to llama_stack/providers/remote/safety/bedrock/__init__.py diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py similarity index 68% rename from llama_stack/providers/adapters/safety/bedrock/bedrock.py rename to llama_stack/providers/remote/safety/bedrock/bedrock.py index 3203e36f4..e14dbd2a4 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -9,11 +9,10 @@ import logging from typing import Any, Dict, List -import boto3 - from llama_stack.apis.safety import * # noqa from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import ShieldsProtocolPrivate +from llama_stack.providers.utils.bedrock.client import create_bedrock_client from .config import BedrockSafetyConfig @@ -28,17 +27,13 @@ BEDROCK_SUPPORTED_SHIELDS = [ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: BedrockSafetyConfig) -> None: - if not config.aws_profile: - raise ValueError(f"Missing boto_client aws_profile in model info::{config}") self.config = config self.registered_shields = [] async def initialize(self) -> None: try: - print(f"initializing with profile --- > {self.config}") - self.boto_client = boto3.Session( - profile_name=self.config.aws_profile - ).client("bedrock-runtime") + self.bedrock_runtime_client = create_bedrock_client(self.config) + self.bedrock_client = create_bedrock_client(self.config, "bedrock") except Exception as e: raise RuntimeError("Error initializing BedrockSafetyAdapter") from e @@ -49,19 +44,28 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): raise ValueError("Registering dynamic shields is not supported") async def list_shields(self) -> List[ShieldDef]: - raise NotImplementedError( - """ - `list_shields` not implemented; this should read all guardrails from - bedrock and populate guardrailId and guardrailVersion in the ShieldDef. - """ - ) + response = self.bedrock_client.list_guardrails() + shields = [] + for guardrail in response["guardrails"]: + # populate the shield def with the guardrail id and version + shield_def = ShieldDef( + identifier=guardrail["id"], + shield_type=ShieldType.generic_content_shield.value, + params={ + "guardrailIdentifier": guardrail["id"], + "guardrailVersion": guardrail["version"], + }, + ) + self.registered_shields.append(shield_def) + shields.append(shield_def) + return shields async def run_shield( - self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + self, identifier: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: - shield_def = await self.shield_store.get_shield(shield_type) + shield_def = await self.shield_store.get_shield(identifier) if not shield_def: - raise ValueError(f"Unknown shield {shield_type}") + raise ValueError(f"Unknown shield {identifier}") """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format ```content = [ @@ -88,7 +92,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:" ) - response = self.boto_client.apply_guardrail( + response = self.bedrock_runtime_client.apply_guardrail( guardrailIdentifier=shield_params["guardrailIdentifier"], guardrailVersion=shield_params["guardrailVersion"], source="OUTPUT", # or 'INPUT' depending on your use case @@ -104,10 +108,12 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): # guardrails returns a list - however for this implementation we will leverage the last values metadata = dict(assessment) - return SafetyViolation( - user_message=user_message, - violation_level=ViolationLevel.ERROR, - metadata=metadata, + return RunShieldResponse( + violation=SafetyViolation( + user_message=user_message, + violation_level=ViolationLevel.ERROR, + metadata=metadata, + ) ) - return None + return RunShieldResponse() diff --git a/llama_stack/providers/impls/meta_reference/memory/config.py b/llama_stack/providers/remote/safety/bedrock/config.py similarity index 68% rename from llama_stack/providers/impls/meta_reference/memory/config.py rename to llama_stack/providers/remote/safety/bedrock/config.py index b1c94c889..8c61decf3 100644 --- a/llama_stack/providers/impls/meta_reference/memory/config.py +++ b/llama_stack/providers/remote/safety/bedrock/config.py @@ -4,10 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel +from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig @json_schema_type -class FaissImplConfig(BaseModel): ... +class BedrockSafetyConfig(BedrockBaseConfig): + pass diff --git a/llama_stack/providers/adapters/safety/sample/__init__.py b/llama_stack/providers/remote/safety/sample/__init__.py similarity index 100% rename from llama_stack/providers/adapters/safety/sample/__init__.py rename to llama_stack/providers/remote/safety/sample/__init__.py diff --git a/llama_stack/providers/adapters/safety/sample/config.py b/llama_stack/providers/remote/safety/sample/config.py similarity index 100% rename from llama_stack/providers/adapters/safety/sample/config.py rename to llama_stack/providers/remote/safety/sample/config.py diff --git a/llama_stack/providers/adapters/safety/sample/sample.py b/llama_stack/providers/remote/safety/sample/sample.py similarity index 100% rename from llama_stack/providers/adapters/safety/sample/sample.py rename to llama_stack/providers/remote/safety/sample/sample.py diff --git a/llama_stack/providers/adapters/safety/together/__init__.py b/llama_stack/providers/remote/safety/together/__init__.py similarity index 100% rename from llama_stack/providers/adapters/safety/together/__init__.py rename to llama_stack/providers/remote/safety/together/__init__.py diff --git a/llama_stack/providers/adapters/safety/together/config.py b/llama_stack/providers/remote/safety/together/config.py similarity index 100% rename from llama_stack/providers/adapters/safety/together/config.py rename to llama_stack/providers/remote/safety/together/config.py diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/remote/safety/together/together.py similarity index 93% rename from llama_stack/providers/adapters/safety/together/together.py rename to llama_stack/providers/remote/safety/together/together.py index da45ed5b8..9f92626af 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/remote/safety/together/together.py @@ -43,11 +43,11 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat ] async def run_shield( - self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + self, identifier: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: - shield_def = await self.shield_store.get_shield(shield_type) + shield_def = await self.shield_store.get_shield(identifier) if not shield_def: - raise ValueError(f"Unknown shield {shield_type}") + raise ValueError(f"Unknown shield {identifier}") model = shield_def.params.get("model", "llama_guard") if model not in TOGETHER_SHIELD_MODEL_MAP: diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/__init__.py b/llama_stack/providers/remote/telemetry/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/__init__.py rename to llama_stack/providers/remote/telemetry/__init__.py diff --git a/llama_stack/providers/adapters/telemetry/opentelemetry/__init__.py b/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py similarity index 100% rename from llama_stack/providers/adapters/telemetry/opentelemetry/__init__.py rename to llama_stack/providers/remote/telemetry/opentelemetry/__init__.py diff --git a/llama_stack/providers/adapters/telemetry/opentelemetry/config.py b/llama_stack/providers/remote/telemetry/opentelemetry/config.py similarity index 100% rename from llama_stack/providers/adapters/telemetry/opentelemetry/config.py rename to llama_stack/providers/remote/telemetry/opentelemetry/config.py diff --git a/llama_stack/providers/adapters/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py similarity index 100% rename from llama_stack/providers/adapters/telemetry/opentelemetry/opentelemetry.py rename to llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py diff --git a/llama_stack/providers/adapters/telemetry/sample/__init__.py b/llama_stack/providers/remote/telemetry/sample/__init__.py similarity index 100% rename from llama_stack/providers/adapters/telemetry/sample/__init__.py rename to llama_stack/providers/remote/telemetry/sample/__init__.py diff --git a/llama_stack/providers/adapters/telemetry/sample/config.py b/llama_stack/providers/remote/telemetry/sample/config.py similarity index 100% rename from llama_stack/providers/adapters/telemetry/sample/config.py rename to llama_stack/providers/remote/telemetry/sample/config.py diff --git a/llama_stack/providers/adapters/telemetry/sample/sample.py b/llama_stack/providers/remote/telemetry/sample/sample.py similarity index 100% rename from llama_stack/providers/adapters/telemetry/sample/sample.py rename to llama_stack/providers/remote/telemetry/sample/sample.py diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 153ade0da..86ecae1e9 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -11,7 +11,7 @@ import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.impls.meta_reference.agents import ( +from llama_stack.providers.inline.meta_reference.agents import ( MetaReferenceAgentsImplConfig, ) diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py index 71253871d..ba60b9925 100644 --- a/llama_stack/providers/tests/inference/conftest.py +++ b/llama_stack/providers/tests/inference/conftest.py @@ -19,12 +19,11 @@ def pytest_addoption(parser): def pytest_configure(config): - config.addinivalue_line( - "markers", "llama_8b: mark test to run only with the given model" - ) - config.addinivalue_line( - "markers", "llama_3b: mark test to run only with the given model" - ) + for model in ["llama_8b", "llama_3b", "llama_vision"]: + config.addinivalue_line( + "markers", f"{model}: mark test to run only with the given model" + ) + for fixture_name in INFERENCE_FIXTURES: config.addinivalue_line( "markers", @@ -37,6 +36,14 @@ MODEL_PARAMS = [ pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"), ] +VISION_MODEL_PARAMS = [ + pytest.param( + "Llama3.2-11B-Vision-Instruct", + marks=pytest.mark.llama_vision, + id="llama_vision", + ), +] + def pytest_generate_tests(metafunc): if "inference_model" in metafunc.fixturenames: @@ -44,7 +51,11 @@ def pytest_generate_tests(metafunc): if model: params = [pytest.param(model, id="")] else: - params = MODEL_PARAMS + cls_name = metafunc.cls.__name__ + if "Vision" in cls_name: + params = VISION_MODEL_PARAMS + else: + params = MODEL_PARAMS metafunc.parametrize( "inference_model", diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 896acbad8..9db70888e 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -10,14 +10,16 @@ import pytest import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider - -from llama_stack.providers.adapters.inference.fireworks import FireworksImplConfig -from llama_stack.providers.adapters.inference.ollama import OllamaImplConfig -from llama_stack.providers.adapters.inference.together import TogetherImplConfig -from llama_stack.providers.impls.meta_reference.inference import ( +from llama_stack.providers.inline.meta_reference.inference import ( MetaReferenceInferenceConfig, ) + +from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig +from llama_stack.providers.remote.inference.ollama import OllamaImplConfig +from llama_stack.providers.remote.inference.together import TogetherImplConfig +from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 + from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail @@ -78,6 +80,21 @@ def inference_ollama(inference_model) -> ProviderFixture: ) +@pytest.fixture(scope="session") +def inference_vllm_remote() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="remote::vllm", + provider_type="remote::vllm", + config=VLLMInferenceAdapterConfig( + url=get_env_or_fail("VLLM_URL"), + ).model_dump(), + ) + ], + ) + + @pytest.fixture(scope="session") def inference_fireworks() -> ProviderFixture: return ProviderFixture( @@ -109,7 +126,14 @@ def inference_together() -> ProviderFixture: ) -INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together", "remote"] +INFERENCE_FIXTURES = [ + "meta_reference", + "ollama", + "fireworks", + "together", + "vllm_remote", + "remote", +] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/inference/pasta.jpeg b/llama_stack/providers/tests/inference/pasta.jpeg new file mode 100644 index 000000000..e8299321c Binary files /dev/null and b/llama_stack/providers/tests/inference/pasta.jpeg differ diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index 29fdc43a4..342117536 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import itertools import pytest @@ -15,6 +14,9 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +from .utils import group_chunks + + # How to run this test: # # pytest -v -s llama_stack/providers/tests/inference/test_inference.py @@ -22,15 +24,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403 # --env FIREWORKS_API_KEY= -def group_chunks(response): - return { - event_type: list(group) - for event_type, group in itertools.groupby( - response, key=lambda chunk: chunk.event.event_type - ) - } - - def get_expected_stop_reason(model: str): return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py new file mode 100644 index 000000000..1939d6934 --- /dev/null +++ b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +import pytest +from PIL import Image as PIL_Image + + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 + +from .utils import group_chunks + +THIS_DIR = Path(__file__).parent + + +class TestVisionModelInference: + @pytest.mark.asyncio + async def test_vision_chat_completion_non_streaming( + self, inference_model, inference_stack + ): + inference_impl, _ = inference_stack + + provider = inference_impl.routing_table.get_provider_impl(inference_model) + if provider.__provider_spec__.provider_type not in ( + "meta-reference", + "remote::together", + "remote::fireworks", + "remote::ollama", + ): + pytest.skip( + "Other inference providers don't support vision chat completion() yet" + ) + + images = [ + ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")), + ImageMedia( + image=URL( + uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" + ) + ), + ] + + # These are a bit hit-and-miss, need to be careful + expected_strings_to_check = [ + ["spaghetti"], + ["puppy"], + ] + for image, expected_strings in zip(images, expected_strings_to_check): + response = await inference_impl.chat_completion( + model=inference_model, + messages=[ + SystemMessage(content="You are a helpful assistant."), + UserMessage( + content=[image, "Describe this image in two sentences."] + ), + ], + stream=False, + ) + + assert isinstance(response, ChatCompletionResponse) + assert response.completion_message.role == "assistant" + assert isinstance(response.completion_message.content, str) + for expected_string in expected_strings: + assert expected_string in response.completion_message.content + + @pytest.mark.asyncio + async def test_vision_chat_completion_streaming( + self, inference_model, inference_stack + ): + inference_impl, _ = inference_stack + + provider = inference_impl.routing_table.get_provider_impl(inference_model) + if provider.__provider_spec__.provider_type not in ( + "meta-reference", + "remote::together", + "remote::fireworks", + "remote::ollama", + ): + pytest.skip( + "Other inference providers don't support vision chat completion() yet" + ) + + images = [ + ImageMedia( + image=URL( + uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" + ) + ), + ] + expected_strings_to_check = [ + ["puppy"], + ] + for image, expected_strings in zip(images, expected_strings_to_check): + response = [ + r + async for r in await inference_impl.chat_completion( + model=inference_model, + messages=[ + SystemMessage(content="You are a helpful assistant."), + UserMessage( + content=[image, "Describe this image in two sentences."] + ), + ], + stream=True, + ) + ] + + assert len(response) > 0 + assert all( + isinstance(chunk, ChatCompletionResponseStreamChunk) + for chunk in response + ) + grouped = group_chunks(response) + assert len(grouped[ChatCompletionResponseEventType.start]) == 1 + assert len(grouped[ChatCompletionResponseEventType.progress]) > 0 + assert len(grouped[ChatCompletionResponseEventType.complete]) == 1 + + content = "".join( + chunk.event.delta + for chunk in grouped[ChatCompletionResponseEventType.progress] + ) + for expected_string in expected_strings: + assert expected_string in content diff --git a/llama_stack/providers/tests/inference/utils.py b/llama_stack/providers/tests/inference/utils.py new file mode 100644 index 000000000..aa8d377e9 --- /dev/null +++ b/llama_stack/providers/tests/inference/utils.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import itertools + + +def group_chunks(response): + return { + event_type: list(group) + for event_type, group in itertools.groupby( + response, key=lambda chunk: chunk.event.event_type + ) + } diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index adeab8476..b30e0fae4 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -5,16 +5,18 @@ # the root directory of this source tree. import os +import tempfile import pytest import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.adapters.memory.pgvector import PGVectorConfig -from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig -from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig +from llama_stack.providers.inline.meta_reference.memory import FaissImplConfig +from llama_stack.providers.remote.memory.pgvector import PGVectorConfig +from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail @@ -26,12 +28,15 @@ def memory_remote() -> ProviderFixture: @pytest.fixture(scope="session") def memory_meta_reference() -> ProviderFixture: + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") return ProviderFixture( providers=[ Provider( provider_id="meta-reference", provider_type="meta-reference", - config=FaissImplConfig().model_dump(), + config=FaissImplConfig( + kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(), + ).model_dump(), ) ], ) diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 74f8ef503..4789558ff 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -8,11 +8,11 @@ import pytest import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.adapters.safety.together import TogetherSafetyConfig -from llama_stack.providers.impls.meta_reference.safety import ( +from llama_stack.providers.inline.meta_reference.safety import ( LlamaGuardShieldConfig, SafetyConfig, ) +from llama_stack.providers.remote.safety.together import TogetherSafetyConfig from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 diff --git a/llama_stack/providers/utils/bedrock/client.py b/llama_stack/providers/utils/bedrock/client.py new file mode 100644 index 000000000..77781c729 --- /dev/null +++ b/llama_stack/providers/utils/bedrock/client.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import boto3 +from botocore.client import BaseClient +from botocore.config import Config + +from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig +from llama_stack.providers.utils.bedrock.refreshable_boto_session import ( + RefreshableBotoSession, +) + + +def create_bedrock_client( + config: BedrockBaseConfig, service_name: str = "bedrock-runtime" +) -> BaseClient: + """Creates a boto3 client for Bedrock services with the given configuration. + + Args: + config: The Bedrock configuration containing AWS credentials and settings + service_name: The AWS service name to create client for (default: "bedrock-runtime") + + Returns: + A configured boto3 client + """ + if config.aws_access_key_id and config.aws_secret_access_key: + retries_config = { + k: v + for k, v in dict( + total_max_attempts=config.total_max_attempts, + mode=config.retry_mode, + ).items() + if v is not None + } + + config_args = { + k: v + for k, v in dict( + region_name=config.region_name, + retries=retries_config if retries_config else None, + connect_timeout=config.connect_timeout, + read_timeout=config.read_timeout, + ).items() + if v is not None + } + + boto3_config = Config(**config_args) + + session_args = { + "aws_access_key_id": config.aws_access_key_id, + "aws_secret_access_key": config.aws_secret_access_key, + "aws_session_token": config.aws_session_token, + "region_name": config.region_name, + "profile_name": config.profile_name, + "session_ttl": config.session_ttl, + } + + # Remove None values + session_args = {k: v for k, v in session_args.items() if v is not None} + + boto3_session = boto3.session.Session(**session_args) + return boto3_session.client(service_name, config=boto3_config) + else: + return ( + RefreshableBotoSession( + region_name=config.region_name, + profile_name=config.profile_name, + session_ttl=config.session_ttl, + ) + .refreshable_session() + .client(service_name) + ) diff --git a/llama_stack/providers/adapters/inference/bedrock/config.py b/llama_stack/providers/utils/bedrock/config.py similarity index 90% rename from llama_stack/providers/adapters/inference/bedrock/config.py rename to llama_stack/providers/utils/bedrock/config.py index 72d2079b9..55c5582a1 100644 --- a/llama_stack/providers/adapters/inference/bedrock/config.py +++ b/llama_stack/providers/utils/bedrock/config.py @@ -1,55 +1,59 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import * # noqa: F403 - -from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field - - -@json_schema_type -class BedrockConfig(BaseModel): - aws_access_key_id: Optional[str] = Field( - default=None, - description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", - ) - aws_secret_access_key: Optional[str] = Field( - default=None, - description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", - ) - aws_session_token: Optional[str] = Field( - default=None, - description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", - ) - region_name: Optional[str] = Field( - default=None, - description="The default AWS Region to use, for example, us-west-1 or us-west-2." - "Default use environment variable: AWS_DEFAULT_REGION", - ) - profile_name: Optional[str] = Field( - default=None, - description="The profile name that contains credentials to use." - "Default use environment variable: AWS_PROFILE", - ) - total_max_attempts: Optional[int] = Field( - default=None, - description="An integer representing the maximum number of attempts that will be made for a single request, " - "including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS", - ) - retry_mode: Optional[str] = Field( - default=None, - description="A string representing the type of retries Boto3 will perform." - "Default use environment variable: AWS_RETRY_MODE", - ) - connect_timeout: Optional[float] = Field( - default=60, - description="The time in seconds till a timeout exception is thrown when attempting to make a connection. " - "The default is 60 seconds.", - ) - read_timeout: Optional[float] = Field( - default=60, - description="The time in seconds till a timeout exception is thrown when attempting to read from a connection." - "The default is 60 seconds.", - ) +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class BedrockBaseConfig(BaseModel): + aws_access_key_id: Optional[str] = Field( + default=None, + description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", + ) + aws_secret_access_key: Optional[str] = Field( + default=None, + description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", + ) + aws_session_token: Optional[str] = Field( + default=None, + description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", + ) + region_name: Optional[str] = Field( + default=None, + description="The default AWS Region to use, for example, us-west-1 or us-west-2." + "Default use environment variable: AWS_DEFAULT_REGION", + ) + profile_name: Optional[str] = Field( + default=None, + description="The profile name that contains credentials to use." + "Default use environment variable: AWS_PROFILE", + ) + total_max_attempts: Optional[int] = Field( + default=None, + description="An integer representing the maximum number of attempts that will be made for a single request, " + "including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS", + ) + retry_mode: Optional[str] = Field( + default=None, + description="A string representing the type of retries Boto3 will perform." + "Default use environment variable: AWS_RETRY_MODE", + ) + connect_timeout: Optional[float] = Field( + default=60, + description="The time in seconds till a timeout exception is thrown when attempting to make a connection. " + "The default is 60 seconds.", + ) + read_timeout: Optional[float] = Field( + default=60, + description="The time in seconds till a timeout exception is thrown when attempting to read from a connection." + "The default is 60 seconds.", + ) + session_ttl: Optional[int] = Field( + default=3600, + description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).", + ) diff --git a/llama_stack/providers/utils/bedrock/refreshable_boto_session.py b/llama_stack/providers/utils/bedrock/refreshable_boto_session.py new file mode 100644 index 000000000..f37563930 --- /dev/null +++ b/llama_stack/providers/utils/bedrock/refreshable_boto_session.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import datetime +from time import time +from uuid import uuid4 + +from boto3 import Session +from botocore.credentials import RefreshableCredentials +from botocore.session import get_session + + +class RefreshableBotoSession: + """ + Boto Helper class which lets us create a refreshable session so that we can cache the client or resource. + + Usage + ----- + session = RefreshableBotoSession().refreshable_session() + + client = session.client("s3") # we now can cache this client object without worrying about expiring credentials + """ + + def __init__( + self, + region_name: str = None, + profile_name: str = None, + sts_arn: str = None, + session_name: str = None, + session_ttl: int = 30000, + ): + """ + Initialize `RefreshableBotoSession` + + Parameters + ---------- + region_name : str (optional) + Default region when creating a new connection. + + profile_name : str (optional) + The name of a profile to use. + + sts_arn : str (optional) + The role arn to sts before creating a session. + + session_name : str (optional) + An identifier for the assumed role session. (required when `sts_arn` is given) + + session_ttl : int (optional) + An integer number to set the TTL for each session. Beyond this session, it will renew the token. + 50 minutes by default which is before the default role expiration of 1 hour + """ + + self.region_name = region_name + self.profile_name = profile_name + self.sts_arn = sts_arn + self.session_name = session_name or uuid4().hex + self.session_ttl = session_ttl + + def __get_session_credentials(self): + """ + Get session credentials + """ + session = Session(region_name=self.region_name, profile_name=self.profile_name) + + # if sts_arn is given, get credential by assuming the given role + if self.sts_arn: + sts_client = session.client( + service_name="sts", region_name=self.region_name + ) + response = sts_client.assume_role( + RoleArn=self.sts_arn, + RoleSessionName=self.session_name, + DurationSeconds=self.session_ttl, + ).get("Credentials") + + credentials = { + "access_key": response.get("AccessKeyId"), + "secret_key": response.get("SecretAccessKey"), + "token": response.get("SessionToken"), + "expiry_time": response.get("Expiration").isoformat(), + } + else: + session_credentials = session.get_credentials().get_frozen_credentials() + credentials = { + "access_key": session_credentials.access_key, + "secret_key": session_credentials.secret_key, + "token": session_credentials.token, + "expiry_time": datetime.datetime.fromtimestamp( + time() + self.session_ttl, datetime.timezone.utc + ).isoformat(), + } + + return credentials + + def refreshable_session(self) -> Session: + """ + Get refreshable boto3 session. + """ + # Get refreshable credentials + refreshable_credentials = RefreshableCredentials.create_from_metadata( + metadata=self.__get_session_credentials(), + refresh_using=self.__get_session_credentials, + method="sts-assume-role", + ) + + # attach refreshable credentials current session + session = get_session() + session._credentials = refreshable_credentials + session.set_config_variable("region", self.region_name) + autorefresh_session = Session(botocore_session=session) + + return autorefresh_session diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 086227c73..cc3e7a2ce 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -46,6 +46,9 @@ def text_from_choice(choice) -> str: if hasattr(choice, "delta") and choice.delta: return choice.delta.content + if hasattr(choice, "message"): + return choice.message.content + return choice.text @@ -99,7 +102,6 @@ def process_chat_completion_response( async def process_completion_stream_response( stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat ) -> AsyncGenerator: - stop_reason = None async for chunk in stream: @@ -158,6 +160,10 @@ async def process_chat_completion_stream_response( break text = text_from_choice(choice) + if not text: + # Sometimes you get empty chunks from providers + continue + # check if its a tool call ( aka starts with <|python_tag|> ) if not ipython and text.startswith("<|python_tag|>"): ipython = True diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 386146ed9..9decf5a00 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -3,10 +3,16 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + +import base64 +import io import json from typing import Tuple +import httpx + from llama_models.llama3.api.chat_format import ChatFormat +from PIL import Image as PIL_Image from termcolor import cprint from llama_models.llama3.api.datatypes import * # noqa: F403 @@ -24,6 +30,90 @@ from llama_models.sku_list import resolve_model from llama_stack.providers.utils.inference import supported_inference_models +def content_has_media(content: InterleavedTextMedia): + def _has_media_content(c): + return isinstance(c, ImageMedia) + + if isinstance(content, list): + return any(_has_media_content(c) for c in content) + else: + return _has_media_content(content) + + +def messages_have_media(messages: List[Message]): + return any(content_has_media(m.content) for m in messages) + + +def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]): + if isinstance(request, ChatCompletionRequest): + return messages_have_media(request.messages) + else: + return content_has_media(request.content) + + +async def convert_image_media_to_url( + media: ImageMedia, download: bool = False, include_format: bool = True +) -> str: + if isinstance(media.image, PIL_Image.Image): + if media.image.format == "PNG": + format = "png" + elif media.image.format == "GIF": + format = "gif" + elif media.image.format == "JPEG": + format = "jpeg" + else: + raise ValueError(f"Unsupported image format {media.image.format}") + + bytestream = io.BytesIO() + media.image.save(bytestream, format=media.image.format) + bytestream.seek(0) + content = bytestream.getvalue() + else: + if not download: + return media.image.uri + else: + assert isinstance(media.image, URL) + async with httpx.AsyncClient() as client: + r = await client.get(media.image.uri) + content = r.content + content_type = r.headers.get("content-type") + if content_type: + format = content_type.split("/")[-1] + else: + format = "png" + + if include_format: + return f"data:image/{format};base64," + base64.b64encode(content).decode( + "utf-8" + ) + else: + return base64.b64encode(content).decode("utf-8") + + +async def convert_message_to_dict(message: Message) -> dict: + async def _convert_content(content) -> dict: + if isinstance(content, ImageMedia): + return { + "type": "image_url", + "image_url": { + "url": await convert_image_media_to_url(content), + }, + } + else: + assert isinstance(content, str) + return {"type": "text", "text": content} + + if isinstance(message.content, list): + content = [await _convert_content(c) for c in message.content] + else: + content = [await _convert_content(message.content)] + + return { + "role": message.role, + "content": content, + } + + def completion_request_to_prompt( request: CompletionRequest, formatter: ChatFormat ) -> str: diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index c84212eed..0a21bf4ca 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -4,10 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import re from enum import Enum from typing import Literal, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from typing_extensions import Annotated from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR @@ -51,6 +52,23 @@ class PostgresKVStoreConfig(CommonConfig): db: str = "llamastack" user: str password: Optional[str] = None + table_name: str = "llamastack_kvstore" + + @classmethod + @field_validator("table_name") + def validate_table_name(cls, v: str) -> str: + # PostgreSQL identifiers rules: + # - Must start with a letter or underscore + # - Can contain letters, numbers, and underscores + # - Maximum length is 63 bytes + pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$" + if not re.match(pattern, v): + raise ValueError( + "Invalid table name. Must start with letter or underscore and contain only letters, numbers, and underscores" + ) + if len(v) > 63: + raise ValueError("Table name must be less than 63 characters") + return v KVStoreConfig = Annotated[ diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py index a3cabc206..469f400d0 100644 --- a/llama_stack/providers/utils/kvstore/kvstore.py +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -43,7 +43,9 @@ async def kvstore_impl(config: KVStoreConfig) -> KVStore: impl = SqliteKVStoreImpl(config) elif config.type == KVStoreType.postgres.value: - raise NotImplementedError() + from .postgres import PostgresKVStoreImpl + + impl = PostgresKVStoreImpl(config) else: raise ValueError(f"Unknown kvstore type {config.type}") diff --git a/llama_stack/providers/utils/kvstore/postgres/__init__.py b/llama_stack/providers/utils/kvstore/postgres/__init__.py new file mode 100644 index 000000000..efbf6299d --- /dev/null +++ b/llama_stack/providers/utils/kvstore/postgres/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .postgres import PostgresKVStoreImpl # noqa: F401 F403 diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py new file mode 100644 index 000000000..23ceb58e4 --- /dev/null +++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import datetime +from typing import List, Optional + +import psycopg2 +from psycopg2.extras import DictCursor + +from ..api import KVStore +from ..config import PostgresKVStoreConfig + + +class PostgresKVStoreImpl(KVStore): + def __init__(self, config: PostgresKVStoreConfig): + self.config = config + self.conn = None + self.cursor = None + + async def initialize(self) -> None: + try: + self.conn = psycopg2.connect( + host=self.config.host, + port=self.config.port, + database=self.config.db, + user=self.config.user, + password=self.config.password, + ) + self.conn.autocommit = True + self.cursor = self.conn.cursor(cursor_factory=DictCursor) + + # Create table if it doesn't exist + self.cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.config.table_name} ( + key TEXT PRIMARY KEY, + value TEXT, + expiration TIMESTAMP + ) + """ + ) + except Exception as e: + import traceback + + traceback.print_exc() + raise RuntimeError("Could not connect to PostgreSQL database server") from e + + def _namespaced_key(self, key: str) -> str: + if not self.config.namespace: + return key + return f"{self.config.namespace}:{key}" + + async def set( + self, key: str, value: str, expiration: Optional[datetime] = None + ) -> None: + key = self._namespaced_key(key) + self.cursor.execute( + f""" + INSERT INTO {self.config.table_name} (key, value, expiration) + VALUES (%s, %s, %s) + ON CONFLICT (key) DO UPDATE + SET value = EXCLUDED.value, expiration = EXCLUDED.expiration + """, + (key, value, expiration), + ) + + async def get(self, key: str) -> Optional[str]: + key = self._namespaced_key(key) + self.cursor.execute( + f""" + SELECT value FROM {self.config.table_name} + WHERE key = %s + AND (expiration IS NULL OR expiration > NOW()) + """, + (key,), + ) + result = self.cursor.fetchone() + return result[0] if result else None + + async def delete(self, key: str) -> None: + key = self._namespaced_key(key) + self.cursor.execute( + f"DELETE FROM {self.config.table_name} WHERE key = %s", + (key,), + ) + + async def range(self, start_key: str, end_key: str) -> List[str]: + start_key = self._namespaced_key(start_key) + end_key = self._namespaced_key(end_key) + + self.cursor.execute( + f""" + SELECT value FROM {self.config.table_name} + WHERE key >= %s AND key < %s + AND (expiration IS NULL OR expiration > NOW()) + ORDER BY key + """, + (start_key, end_key), + ) + return [row[0] for row in self.cursor.fetchall()] diff --git a/llama_stack/templates/fireworks/build.yaml b/llama_stack/templates/fireworks/build.yaml index 994e4c641..5b662c213 100644 --- a/llama_stack/templates/fireworks/build.yaml +++ b/llama_stack/templates/fireworks/build.yaml @@ -6,8 +6,6 @@ distribution_spec: memory: - meta-reference - remote::weaviate - - remote::chromadb - - remote::pgvector safety: meta-reference agents: meta-reference telemetry: meta-reference