Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-19 22:28:40 +00:00)

Merge branch 'main' into vllm_health_check

Commit c18b585d32: 143 changed files with 9210 additions and 5347 deletions

.github/workflows/providers-build.yml (2 changes)
@@ -10,6 +10,7 @@ on:
       - 'llama_stack/distribution/build.*'
       - 'llama_stack/distribution/*.sh'
       - '.github/workflows/providers-build.yml'
+      - 'llama_stack/templates/**'
   pull_request:
     paths:
       - 'llama_stack/cli/stack/build.py'

@@ -17,6 +18,7 @@ on:
       - 'llama_stack/distribution/build.*'
       - 'llama_stack/distribution/*.sh'
       - '.github/workflows/providers-build.yml'
+      - 'llama_stack/templates/**'

 concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
@@ -72,7 +72,7 @@ source .venv/bin/activate
 ```

 > [!NOTE]
-> You can pin a specific version of Python to use for `uv` by adding a `.python-version` file in the root project directory.
+> You can use a specific version of Python with `uv` by adding the `--python <version>` flag (e.g. `--python 3.11`)
 > Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`.
 > For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/).
docs/_static/llama-stack-spec.html (1386 changes): file diff suppressed because it is too large
docs/_static/llama-stack-spec.yaml (1047 changes): file diff suppressed because it is too large
@@ -30,6 +30,9 @@ from llama_stack.strong_typing.schema import (
     Schema,
     SchemaOptions,
 )
+from typing import get_origin, get_args
+from typing import Annotated
+from fastapi import UploadFile
 from llama_stack.strong_typing.serialization import json_dump_string, object_to_json

 from .operations import (

@@ -618,6 +621,45 @@ class Generator:
                 },
                 required=True,
             )
+        # data passed in request body as multipart/form-data
+        elif op.multipart_params:
+            builder = ContentBuilder(self.schema_builder)
+
+            # Create schema properties for multipart form fields
+            properties = {}
+            required_fields = []
+
+            for name, param_type in op.multipart_params:
+                if get_origin(param_type) is Annotated:
+                    base_type = get_args(param_type)[0]
+                else:
+                    base_type = param_type
+                if base_type is UploadFile:
+                    # File upload
+                    properties[name] = {
+                        "type": "string",
+                        "format": "binary"
+                    }
+                else:
+                    # Form field
+                    properties[name] = self.schema_builder.classdef_to_ref(base_type)
+
+                required_fields.append(name)
+
+            multipart_schema = {
+                "type": "object",
+                "properties": properties,
+                "required": required_fields
+            }
+
+            requestBody = RequestBody(
+                content={
+                    "multipart/form-data": {
+                        "schema": multipart_schema
+                    }
+                },
+                required=True,
+            )
         # data passed in payload as JSON and mapped to request parameters
         elif op.request_params:
             builder = ContentBuilder(self.schema_builder)
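For context, the new `multipart_params` branch above targets endpoint signatures that declare `UploadFile` and `Form` parameters. The sketch below is illustrative only (the `upload_document` endpoint and its fields are hypothetical, not taken from this diff, and it assumes `fastapi` is installed); it shows the kind of signature the branch handles and the shape of the multipart request body it would emit. Note the real generator references form-field schemas via `classdef_to_ref`, i.e. a `$ref`, rather than inlining them.

```python
from typing import Annotated

from fastapi import UploadFile
from fastapi.params import File, Form


async def upload_document(
    file: Annotated[UploadFile, File()],
    purpose: Annotated[str, Form()],
) -> None:
    """Hypothetical endpoint: one binary upload plus one plain form field."""


# Approximate shape of the request body the new branch would emit for such a
# signature (form-field schema shown inline here for readability).
expected_request_body = {
    "multipart/form-data": {
        "schema": {
            "type": "object",
            "properties": {
                "file": {"type": "string", "format": "binary"},
                "purpose": {"type": "string"},
            },
            "required": ["file", "purpose"],
        }
    }
}
print(expected_request_body)
```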
@@ -17,6 +17,12 @@ from termcolor import colored

 from llama_stack.strong_typing.inspection import get_signature

+from typing import get_origin, get_args
+
+from fastapi import UploadFile
+from fastapi.params import File, Form
+from typing import Annotated
+

 def split_prefix(
     s: str, sep: str, prefix: Union[str, Iterable[str]]

@@ -82,6 +88,7 @@ class EndpointOperation:
     :param path_params: Parameters of the operation signature that are passed in the path component of the URL string.
     :param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
     :param request_params: The parameter that corresponds to the data transmitted in the request body.
+    :param multipart_params: Parameters that indicate multipart/form-data request body.
     :param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
     :param response_type: The Python type of the data that is transmitted in the response body.
     :param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.

@@ -98,6 +105,7 @@ class EndpointOperation:
     path_params: List[OperationParameter]
     query_params: List[OperationParameter]
     request_params: Optional[OperationParameter]
+    multipart_params: List[OperationParameter]
     event_type: Optional[type]
     response_type: type
     http_method: HTTPMethod

@@ -252,6 +260,7 @@ def get_endpoint_operations(
         path_params = []
         query_params = []
         request_params = []
+        multipart_params = []

         for param_name, parameter in signature.parameters.items():
             param_type = _get_annotation_type(parameter.annotation, func_ref)

@@ -266,6 +275,8 @@ def get_endpoint_operations(
                     f"parameter '{param_name}' in function '{func_name}' has no type annotation"
                 )

+            is_multipart = _is_multipart_param(param_type)
+
             if prefix in ["get", "delete"]:
                 if route_params is not None and param_name in route_params:
                     path_params.append((param_name, param_type))

@@ -274,6 +285,8 @@ def get_endpoint_operations(
             else:
                 if route_params is not None and param_name in route_params:
                     path_params.append((param_name, param_type))
+                elif is_multipart:
+                    multipart_params.append((param_name, param_type))
                 else:
                     request_params.append((param_name, param_type))

@@ -333,6 +346,7 @@ def get_endpoint_operations(
                 path_params=path_params,
                 query_params=query_params,
                 request_params=request_params,
+                multipart_params=multipart_params,
                 event_type=event_type,
                 response_type=response_type,
                 http_method=http_method,

@@ -377,3 +391,34 @@ def get_endpoint_events(endpoint: type) -> Dict[str, type]:
         results[param_type.__name__] = param_type

     return results
+
+
+def _is_multipart_param(param_type: type) -> bool:
+    """
+    Check if a parameter type indicates multipart form data.
+
+    Returns True if the type is:
+    - UploadFile
+    - Annotated[UploadFile, File()]
+    - Annotated[str, Form()]
+    - Annotated[Any, File()]
+    - Annotated[Any, Form()]
+    """
+    if param_type is UploadFile:
+        return True
+
+    # Check for Annotated types
+    origin = get_origin(param_type)
+    if origin is None:
+        return False
+
+    if origin is Annotated:
+        args = get_args(param_type)
+        if len(args) < 2:
+            return False
+
+        # Check the annotations for File() or Form()
+        for annotation in args[1:]:
+            if isinstance(annotation, (File, Form)):
+                return True
+    return False
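As a quick sanity sketch of the routing that these hunks introduce (assuming `fastapi` is installed; the sample parameter names and types are illustrative, not from this diff): parameters matching the `_is_multipart_param` rule land in `multipart_params`, everything else that is not a path or query parameter stays in `request_params` as a JSON body.

```python
from typing import Annotated, get_args, get_origin

from fastapi import UploadFile
from fastapi.params import File, Form


def is_multipart(param_type) -> bool:
    # Same rule as _is_multipart_param: bare UploadFile, or Annotated[..., File()/Form()]
    if param_type is UploadFile:
        return True
    if get_origin(param_type) is Annotated:
        return any(isinstance(a, (File, Form)) for a in get_args(param_type)[1:])
    return False


multipart_params, request_params = [], []
for name, param_type in [
    ("file", Annotated[UploadFile, File()]),  # multipart
    ("purpose", Annotated[str, Form()]),      # multipart
    ("options", dict),                        # JSON request body
]:
    (multipart_params if is_multipart(param_type) else request_params).append(name)

print(multipart_params)  # ['file', 'purpose']
print(request_params)    # ['options']
```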
@@ -153,6 +153,12 @@ def _validate_api_delete_method_returns_none(method) -> str | None:
         return "has no return type annotation"

     return_type = hints['return']
+
+    # Allow OpenAI endpoints to return response objects since they follow OpenAI specification
+    method_name = getattr(method, '__name__', '')
+    if method_name.startswith('openai_'):
+        return None
+
     if return_type is not None and return_type is not type(None):
         return "does not return None where None is mandatory"
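A minimal sketch of why the `openai_` exemption above is needed (the class and method names below are hypothetical): ordinary delete-style methods must return `None`, while OpenAI-compatible methods are allowed to return a response object.

```python
from typing import get_type_hints


class FilesAPI:
    async def delete_file(self, file_id: str) -> None: ...

    async def openai_delete_file(self, file_id: str) -> dict: ...


def check(method) -> str | None:
    hints = get_type_hints(method)
    if method.__name__.startswith("openai_"):
        return None  # exempt from the None-return rule
    return_type = hints.get("return")
    if return_type is not None and return_type is not type(None):
        return "does not return None where None is mandatory"
    return None


print(check(FilesAPI.delete_file))         # None: passes the rule
print(check(FilesAPI.openai_delete_file))  # None: exempt via the openai_ prefix
```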
@@ -9,29 +9,24 @@ When instantiating an agent, you can provide it a list of tool groups that it ha

 Refer to the [Building AI Applications](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb) notebook for more examples on how to use tools.

-## Types of Tool Group providers
+## Server-side vs. client-side tool execution

-There are three types of providers for tool groups that are supported by Llama Stack.
+Llama Stack allows you to use both server-side and client-side tools. With server-side tools, `agent.create_turn` can perform execution of the tool calls emitted by the model
+transparently giving the user the final answer desired. If client-side tools are provided, the tool call is sent back to the user for execution
+and optional continuation using the `agent.resume_turn` method.

-1. Built-in providers
-2. Model Context Protocol (MCP) providers
-3. Client provided tools
-
-### Built-in providers
+### Server-side tools

-Built-in providers come packaged with Llama Stack. These providers provide common functionalities like web search, code interpretation, and computational capabilities.
+Llama Stack provides built-in providers for some common tools. These include web search, math, and RAG capabilities.

-#### Web Search providers
-There are three web search providers that are supported by Llama Stack.
+#### Web Search

-1. Brave Search
-2. Bing Search
-3. Tavily Search
+You have three providers to execute the web search tool calls generated by a model: Brave Search, Bing Search, and Tavily Search.

-Example client SDK call to register a "websearch" toolgroup that is provided by brave-search.
+To indicate that the web search tool calls should be executed by brave-search, you can point the "builtin::websearch" toolgroup to the "brave-search" provider.

 ```python
-# Register Brave Search tool group
 client.toolgroups.register(
     toolgroup_id="builtin::websearch",
     provider_id="brave-search",

@@ -39,17 +34,17 @@ client.toolgroups.register(
 )
 ```

-The tool requires an API key which can be provided either in the configuration or through the request header `X-LlamaStack-Provider-Data`. The format of the header is `{"<provider_name>_api_key": <your api key>}`.
+The tool requires an API key which can be provided either in the configuration or through the request header `X-LlamaStack-Provider-Data`. The format of the header is:
+```
+{"<provider_name>_api_key": <your api key>}
+```

-> **NOTE:** When using Tavily Search and Bing Search, the inference output will still display "Brave Search." This is because Llama models have been trained with Brave Search as a built-in tool. Tavily and bing is just being used in lieu of Brave search.

-#### WolframAlpha
+#### Math

 The WolframAlpha tool provides access to computational knowledge through the WolframAlpha API.

 ```python
-# Register WolframAlpha tool group
 client.toolgroups.register(
     toolgroup_id="builtin::wolfram_alpha", provider_id="wolfram-alpha"
 )

@@ -83,11 +78,49 @@ Features:

 > **Note:** By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers.

-## Model Context Protocol (MCP) Tools
+## Model Context Protocol (MCP)

-MCP tools are special tools that can interact with llama stack over model context protocol. These tools are dynamically discovered from an MCP endpoint and can be used to extend the agent's capabilities.
+[MCP](https://github.com/modelcontextprotocol) is an upcoming, popular standard for tool discovery and execution. It is a protocol that allows tools to be dynamically discovered
+from an MCP endpoint and can be used to extend the agent's capabilities.

-Refer to [https://github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) for available MCP servers.
+### Using Remote MCP Servers
+
+You can find some popular remote MCP servers [here](https://github.com/jaw9c/awesome-remote-mcp-servers). You can register them as toolgroups in the same way as local providers.
+
+```python
+client.toolgroups.register(
+    toolgroup_id="mcp::deepwiki",
+    provider_id="model-context-protocol",
+    mcp_endpoint=URL(uri="https://mcp.deepwiki.com/sse"),
+)
+```
+
+Note that most of the more useful MCP servers need you to authenticate with them. Many of them use OAuth2.0 for authentication. You can provide authorization headers to send to the MCP server
+using the "Provider Data" abstraction provided by Llama Stack. When making an agent call,
+
+```python
+agent = Agent(
+    ...,
+    tools=["mcp::deepwiki"],
+    extra_headers={
+        "X-LlamaStack-Provider-Data": json.dumps(
+            {
+                "mcp_headers": {
+                    "http://mcp.deepwiki.com/sse": {
+                        "Authorization": "Bearer <your_access_token>",
+                    },
+                },
+            }
+        ),
+    },
+)
+agent.create_turn(...)
+```
+
+### Running your own MCP server
+
+Here's an example of how to run a simple MCP server that exposes a File System as a set of tools to the Llama Stack agent.
+
 ```shell
 # start your MCP server

@@ -106,13 +139,9 @@ client.toolgroups.register(
 )
 ```

-MCP tools require:
-- A valid MCP endpoint URL
-- The endpoint must implement the Model Context Protocol
-- Tools are discovered dynamically from the endpoint


-## Adding Custom Tools
+## Adding Custom (Client-side) Tools

 When you want to use tools other than the built-in tools, you just need to implement a python function with a docstring. The content of the docstring will be used to describe the tool and the parameters and passed
 along to the generative model.
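To illustrate the "Adding Custom (Client-side) Tools" section in the diff above: a client-side tool in this style is just a typed Python function whose docstring describes the tool and its parameters. The example below is illustrative only (not part of this diff); such a function can then be included in the agent's `tools` list alongside builtin toolgroups.

```python
def get_boiling_point(liquid_name: str, celsius: bool = True) -> int:
    """
    Returns the boiling point of a liquid in Celsius or Fahrenheit.

    :param liquid_name: The name of the liquid
    :param celsius: Whether to return the boiling point in Celsius
    :return: The boiling point, or -1 if the liquid is unknown
    """
    if liquid_name.lower() == "water":
        return 100 if celsius else 212
    return -1
```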
docs/source/concepts/api_providers.md (new file, 12 lines)

## API Providers

The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)

Providers come in two flavors:
- **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
- **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.

Most importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
docs/source/concepts/apis.md (new file, 18 lines)

## APIs

A Llama Stack API is described as a collection of REST endpoints. We currently support the following APIs:

- **Inference**: run inference with a LLM
- **Safety**: apply safety policies to the output at a Systems (not only model) level
- **Agents**: run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc.
- **DatasetIO**: interface with datasets and data loaders
- **Scoring**: evaluate outputs of the system
- **Eval**: generate outputs (via Inference or Agents) and perform scoring
- **VectorIO**: perform operations on vector stores, such as adding documents, searching, and deleting documents
- **Telemetry**: collect telemetry data from the system

We are working on adding a few more APIs to complete the application lifecycle. These will include:
- **Batch Inference**: run inference on a dataset of inputs
- **Batch Agents**: run agents on a dataset of inputs
- **Post Training**: fine-tune a Llama model
- **Synthetic Data Generation**: generate synthetic data for model development
docs/source/concepts/distributions.md (new file, 9 lines)

## Distributions

While there is a lot of flexibility to mix-and-match providers, often users will work with a specific set of providers (hardware support, contractual obligations, etc.) We therefore need to provide a _convenient shorthand_ for such collections. We call this shorthand a **Llama Stack Distribution** or a **Distro**. One can think of it as specific pre-packaged versions of the Llama Stack. Here are some examples:

**Remotely Hosted Distro**: These are the simplest to consume from a user perspective. You can simply obtain the API key for these providers, point to a URL and have _all_ Llama Stack APIs working out of the box. Currently, [Fireworks](https://fireworks.ai/) and [Together](https://together.xyz/) provide such easy-to-consume Llama Stack distributions.

**Locally Hosted Distro**: You may want to run Llama Stack on your own hardware. Typically though, you still need to use Inference via an external service. You can use providers like HuggingFace TGI, Fireworks, Together, etc. for this purpose. Or you may have access to GPUs and can run a [vLLM](https://github.com/vllm-project/vllm) or [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) instance. If you "just" have a regular desktop machine, you can use [Ollama](https://ollama.com/) for inference. To provide convenient quick access to these options, we provide a number of such pre-configured locally-hosted Distros.

**On-device Distro**: To run Llama Stack directly on an edge device (mobile phone or a tablet), we provide Distros for [iOS](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/ios_sdk.html) and [Android](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/android_sdk.html)
@@ -1,4 +1,4 @@
-# Evaluation Concepts
+## Evaluation Concepts

 The Llama Stack Evaluation flow allows you to run evaluations on your GenAI application datasets or pre-registered benchmarks.

@@ -10,11 +10,7 @@ We introduce a set of APIs in Llama Stack for supporting running evaluations of
 This guide goes over the sets of APIs and developer experience flow of using Llama Stack to run evaluations for different use cases. Checkout our Colab notebook on working examples with evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).

-## Evaluation Concepts
-
+The Evaluation APIs are associated with a set of Resources. Please visit the Resources section in our [Core Concepts](../concepts/index.md) guide for better high-level understanding.

-The Evaluation APIs are associated with a set of Resources as shown in the following diagram. Please visit the Resources section in our [Core Concepts](../concepts/index.md) guide for better high-level understanding.
-
-

 - **DatasetIO**: defines interface with datasets and data loaders.
   - Associated with `Dataset` resource.

@@ -24,9 +20,9 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
   - Associated with `Benchmark` resource.


-## Open-benchmark Eval
+### Open-benchmark Eval

-### List of open-benchmarks Llama Stack support
+#### List of open-benchmarks Llama Stack support

 Llama stack pre-registers several popular open-benchmarks to easily evaluate model perfomance via CLI.

@@ -39,7 +35,7 @@ The list of open-benchmarks we currently support:

 You can follow this [contributing guide](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) to add more open-benchmarks to Llama Stack

-### Run evaluation on open-benchmarks via CLI
+#### Run evaluation on open-benchmarks via CLI

 We have built-in functionality to run the supported open-benckmarks using llama-stack-client CLI

@@ -74,7 +70,7 @@ evaluation results over there.


-## What's Next?
+#### What's Next?

 - Check out our Colab notebook on working examples with running benchmark evaluations [here](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb#scrollTo=mxLCsP4MvFqP).
 - Check out our [Building Applications - Evaluation](../building_applications/evals.md) guide for more details on how to use the Evaluation APIs to evaluate your applications.
@@ -1,74 +1,23 @@
 # Core Concepts

-```{toctree}
-:maxdepth: 1
-:hidden:
-
-evaluation_concepts
-```
-
 Given Llama Stack's service-oriented philosophy, a few concepts and workflows arise which may not feel completely natural in the LLM landscape, especially if you are coming with a background in other frameworks.

-## APIs
-
-A Llama Stack API is described as a collection of REST endpoints. We currently support the following APIs:
-
-- **Inference**: run inference with a LLM
-- **Safety**: apply safety policies to the output at a Systems (not only model) level
-- **Agents**: run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc.
-- **DatasetIO**: interface with datasets and data loaders
-- **Scoring**: evaluate outputs of the system
-- **Eval**: generate outputs (via Inference or Agents) and perform scoring
-- **VectorIO**: perform operations on vector stores, such as adding documents, searching, and deleting documents
-- **Telemetry**: collect telemetry data from the system
-
-We are working on adding a few more APIs to complete the application lifecycle. These will include:
-- **Batch Inference**: run inference on a dataset of inputs
-- **Batch Agents**: run agents on a dataset of inputs
-- **Post Training**: fine-tune a Llama model
-- **Synthetic Data Generation**: generate synthetic data for model development
-
-## API Providers
-
-The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
-- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
-- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
-- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
-
-Providers come in two flavors:
-- **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code.
-- **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack.
-
-Most importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
-
-## Resources
-
-Some of these APIs are associated with a set of **Resources**. Here is the mapping of APIs to resources:
-
-- **Inference**, **Eval** and **Post Training** are associated with `Model` resources.
-- **Safety** is associated with `Shield` resources.
-- **Tool Runtime** is associated with `ToolGroup` resources.
-- **DatasetIO** is associated with `Dataset` resources.
-- **VectorIO** is associated with `VectorDB` resources.
-- **Scoring** is associated with `ScoringFunction` resources.
-- **Eval** is associated with `Model` and `Benchmark` resources.
-
-Furthermore, we allow these resources to be **federated** across multiple providers. For example, you may have some Llama models served by Fireworks while others are served by AWS Bedrock. Regardless, they will all work seamlessly with the same uniform Inference API provided by Llama Stack.
-
-```{admonition} Registering Resources
-:class: tip
-
-Given this architecture, it is necessary for the Stack to know which provider to use for a given resource. This means you need to explicitly _register_ resources (including models) before you can use them with the associated APIs.
-```
-
-## Distributions
-
-While there is a lot of flexibility to mix-and-match providers, often users will work with a specific set of providers (hardware support, contractual obligations, etc.) We therefore need to provide a _convenient shorthand_ for such collections. We call this shorthand a **Llama Stack Distribution** or a **Distro**. One can think of it as specific pre-packaged versions of the Llama Stack. Here are some examples:
-
-**Remotely Hosted Distro**: These are the simplest to consume from a user perspective. You can simply obtain the API key for these providers, point to a URL and have _all_ Llama Stack APIs working out of the box. Currently, [Fireworks](https://fireworks.ai/) and [Together](https://together.xyz/) provide such easy-to-consume Llama Stack distributions.
-
-**Locally Hosted Distro**: You may want to run Llama Stack on your own hardware. Typically though, you still need to use Inference via an external service. You can use providers like HuggingFace TGI, Fireworks, Together, etc. for this purpose. Or you may have access to GPUs and can run a [vLLM](https://github.com/vllm-project/vllm) or [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) instance. If you "just" have a regular desktop machine, you can use [Ollama](https://ollama.com/) for inference. To provide convenient quick access to these options, we provide a number of such pre-configured locally-hosted Distros.
-
-**On-device Distro**: To run Llama Stack directly on an edge device (mobile phone or a tablet), we provide Distros for [iOS](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/ios_sdk.html) and [Android](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/android_sdk.html)
+```{include} apis.md
+:start-after: ## APIs
+```
+
+```{include} api_providers.md
+:start-after: ## API Providers
+```
+
+```{include} resources.md
+:start-after: ## Resources
+```
+
+```{include} distributions.md
+:start-after: ## Distributions
+```
+
+```{include} evaluation_concepts.md
+:start-after: ## Evaluation Concepts
+```
docs/source/concepts/resources.md (new file, 19 lines)

## Resources

Some of these APIs are associated with a set of **Resources**. Here is the mapping of APIs to resources:

- **Inference**, **Eval** and **Post Training** are associated with `Model` resources.
- **Safety** is associated with `Shield` resources.
- **Tool Runtime** is associated with `ToolGroup` resources.
- **DatasetIO** is associated with `Dataset` resources.
- **VectorIO** is associated with `VectorDB` resources.
- **Scoring** is associated with `ScoringFunction` resources.
- **Eval** is associated with `Model` and `Benchmark` resources.

Furthermore, we allow these resources to be **federated** across multiple providers. For example, you may have some Llama models served by Fireworks while others are served by AWS Bedrock. Regardless, they will all work seamlessly with the same uniform Inference API provided by Llama Stack.

```{admonition} Registering Resources
:class: tip

Given this architecture, it is necessary for the Stack to know which provider to use for a given resource. This means you need to explicitly _register_ resources (including models) before you can use them with the associated APIs.
```
@@ -260,7 +260,41 @@ Containerfile created successfully in /tmp/tmp.viA3a3Rdsg/ContainerfileFROM pyth
 You can now edit ~/meta-llama/llama-stack/tmp/configs/ollama-run.yaml and run `llama stack run ~/meta-llama/llama-stack/tmp/configs/ollama-run.yaml`
 ```

-After this step is successful, you should be able to find the built container image and test it with `llama stack run <path/to/run.yaml>`.
+Now set some environment variables for the inference model ID and Llama Stack Port and create a local directory to mount into the container's file system.
+```
+export INFERENCE_MODEL="llama3.2:3b"
+export LLAMA_STACK_PORT=8321
+mkdir -p ~/.llama
+```
+
+After this step is successful, you should be able to find the built container image and test it with the below Docker command:
+
+```
+docker run -d \
+  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+  -v ~/.llama:/root/.llama \
+  localhost/distribution-ollama:dev \
+  --port $LLAMA_STACK_PORT \
+  --env INFERENCE_MODEL=$INFERENCE_MODEL \
+  --env OLLAMA_URL=http://host.docker.internal:11434
+```
+
+Here are the docker flags and their uses:
+
+* `-d`: Runs the container in the detached mode as a background process
+
+* `-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT`: Maps the container port to the host port for accessing the server
+
+* `-v ~/.llama:/root/.llama`: Mounts the local .llama directory to persist configurations and data
+
+* `localhost/distribution-ollama:dev`: The name and tag of the container image to run
+
+* `--port $LLAMA_STACK_PORT`: Port number for the server to listen on
+
+* `--env INFERENCE_MODEL=$INFERENCE_MODEL`: Sets the model to use for inference
+
+* `--env OLLAMA_URL=http://host.docker.internal:11434`: Configures the URL for the Ollama service
+
 :::

 ::::
docs/source/distributions/k8s/apply.sh (new executable file, 32 lines)

#!/usr/bin/env bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

export POSTGRES_USER=${POSTGRES_USER:-llamastack}
export POSTGRES_DB=${POSTGRES_DB:-llamastack}
export POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-llamastack}

export INFERENCE_MODEL=${INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct}
export SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}

set -euo pipefail
set -x

envsubst < ./vllm-k8s.yaml.template | kubectl apply -f -
envsubst < ./vllm-safety-k8s.yaml.template | kubectl apply -f -
envsubst < ./postgres-k8s.yaml.template | kubectl apply -f -
envsubst < ./chroma-k8s.yaml.template | kubectl apply -f -

kubectl create configmap llama-stack-config --from-file=stack_run_config.yaml \
  --dry-run=client -o yaml > stack-configmap.yaml

kubectl apply -f stack-configmap.yaml

envsubst < ./stack-k8s.yaml.template | kubectl apply -f -
envsubst < ./ingress-k8s.yaml.template | kubectl apply -f -

envsubst < ./ui-k8s.yaml.template | kubectl apply -f -
docs/source/distributions/k8s/chroma-k8s.yaml.template (new file, 66 lines)

apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: chromadb-pvc
spec:
  accessModes:
    - ReadWriteOnce
  resources:
    requests:
      storage: 20Gi
---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: chromadb
spec:
  replicas: 1
  selector:
    matchLabels:
      app: chromadb
  template:
    metadata:
      labels:
        app: chromadb
    spec:
      containers:
        - name: chromadb
          image: chromadb/chroma:latest
          ports:
            - containerPort: 6000
          env:
            - name: CHROMA_HOST
              value: "0.0.0.0"
            - name: CHROMA_PORT
              value: "6000"
            - name: PERSIST_DIRECTORY
              value: "/chroma/chroma"
            - name: CHROMA_DB_IMPL
              value: "duckdb+parquet"
          resources:
            requests:
              memory: "512Mi"
              cpu: "250m"
            limits:
              memory: "2Gi"
              cpu: "1000m"
          volumeMounts:
            - name: chromadb-storage
              mountPath: /chroma/chroma
      volumes:
        - name: chromadb-storage
          persistentVolumeClaim:
            claimName: chromadb-pvc
---
apiVersion: v1
kind: Service
metadata:
  name: chromadb
spec:
  selector:
    app: chromadb
  ports:
    - protocol: TCP
      port: 6000
      targetPort: 6000
  type: ClusterIP
docs/source/distributions/k8s/ingress-k8s.yaml.template (new file, 17 lines)

apiVersion: v1
kind: Service
metadata:
  name: llama-stack-service
spec:
  type: LoadBalancer
  selector:
    app.kubernetes.io/name: llama-stack
  ports:
    - name: llama-stack-api
      port: 8321
      targetPort: 8321
      protocol: TCP
    - name: llama-stack-ui
      port: 8322
      targetPort: 8322
      protocol: TCP
docs/source/distributions/k8s/postgres-k8s.yaml.template
Normal file
66
docs/source/distributions/k8s/postgres-k8s.yaml.template
Normal file
|
|
@ -0,0 +1,66 @@
|
||||||
|
apiVersion: v1
|
||||||
|
kind: PersistentVolumeClaim
|
||||||
|
metadata:
|
||||||
|
name: postgres-pvc
|
||||||
|
spec:
|
||||||
|
accessModes:
|
||||||
|
- ReadWriteOnce
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
storage: 10Gi
|
||||||
|
---
|
||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
metadata:
|
||||||
|
name: postgres
|
||||||
|
spec:
|
||||||
|
replicas: 1
|
||||||
|
selector:
|
||||||
|
matchLabels:
|
||||||
|
app.kubernetes.io/name: postgres
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
app.kubernetes.io/name: postgres
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: postgres
|
||||||
|
image: postgres:15
|
||||||
|
env:
|
||||||
|
- name: POSTGRES_DB
|
||||||
|
value: "${POSTGRES_DB}"
|
||||||
|
- name: POSTGRES_USER
|
||||||
|
value: "${POSTGRES_USER}"
|
||||||
|
- name: POSTGRES_PASSWORD
|
||||||
|
value: "${POSTGRES_PASSWORD}"
|
||||||
|
- name: PGDATA
|
||||||
|
value: "/var/lib/postgresql/data/pgdata"
|
||||||
|
ports:
|
||||||
|
- containerPort: 5432
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
memory: "512Mi"
|
||||||
|
cpu: "250m"
|
||||||
|
limits:
|
||||||
|
memory: "1Gi"
|
||||||
|
cpu: "500m"
|
||||||
|
volumeMounts:
|
||||||
|
- name: postgres-storage
|
||||||
|
mountPath: /var/lib/postgresql/data
|
||||||
|
volumes:
|
||||||
|
- name: postgres-storage
|
||||||
|
persistentVolumeClaim:
|
||||||
|
claimName: postgres-pvc
|
||||||
|
---
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: postgres-server
|
||||||
|
spec:
|
||||||
|
selector:
|
||||||
|
app.kubernetes.io/name: postgres
|
||||||
|
ports:
|
||||||
|
- protocol: TCP
|
||||||
|
port: 5432
|
||||||
|
targetPort: 5432
|
||||||
|
type: ClusterIP
|
||||||
docs/source/distributions/k8s/stack-configmap.yaml (new file, 128 lines)

apiVersion: v1
data:
  stack_run_config.yaml: |
    version: '2'
    image_name: kubernetes-demo
    apis:
    - agents
    - inference
    - safety
    - telemetry
    - tool_runtime
    - vector_io
    providers:
      inference:
      - provider_id: vllm-inference
        provider_type: remote::vllm
        config:
          url: ${env.VLLM_URL:http://localhost:8000/v1}
          max_tokens: ${env.VLLM_MAX_TOKENS:4096}
          api_token: ${env.VLLM_API_TOKEN:fake}
          tls_verify: ${env.VLLM_TLS_VERIFY:true}
      - provider_id: vllm-safety
        provider_type: remote::vllm
        config:
          url: ${env.VLLM_SAFETY_URL:http://localhost:8000/v1}
          max_tokens: ${env.VLLM_MAX_TOKENS:4096}
          api_token: ${env.VLLM_API_TOKEN:fake}
          tls_verify: ${env.VLLM_TLS_VERIFY:true}
      - provider_id: sentence-transformers
        provider_type: inline::sentence-transformers
        config: {}
      vector_io:
      - provider_id: ${env.ENABLE_CHROMADB+chromadb}
        provider_type: remote::chromadb
        config:
          url: ${env.CHROMADB_URL:}
      safety:
      - provider_id: llama-guard
        provider_type: inline::llama-guard
        config:
          excluded_categories: []
      agents:
      - provider_id: meta-reference
        provider_type: inline::meta-reference
        config:
          persistence_store:
            type: postgres
            host: ${env.POSTGRES_HOST:localhost}
            port: ${env.POSTGRES_PORT:5432}
            db: ${env.POSTGRES_DB:llamastack}
            user: ${env.POSTGRES_USER:llamastack}
            password: ${env.POSTGRES_PASSWORD:llamastack}
          responses_store:
            type: postgres
            host: ${env.POSTGRES_HOST:localhost}
            port: ${env.POSTGRES_PORT:5432}
            db: ${env.POSTGRES_DB:llamastack}
            user: ${env.POSTGRES_USER:llamastack}
            password: ${env.POSTGRES_PASSWORD:llamastack}
      telemetry:
      - provider_id: meta-reference
        provider_type: inline::meta-reference
        config:
          service_name: ${env.OTEL_SERVICE_NAME:}
          sinks: ${env.TELEMETRY_SINKS:console}
      tool_runtime:
      - provider_id: brave-search
        provider_type: remote::brave-search
        config:
          api_key: ${env.BRAVE_SEARCH_API_KEY:}
          max_results: 3
      - provider_id: tavily-search
        provider_type: remote::tavily-search
        config:
          api_key: ${env.TAVILY_SEARCH_API_KEY:}
          max_results: 3
      - provider_id: rag-runtime
        provider_type: inline::rag-runtime
        config: {}
      - provider_id: model-context-protocol
        provider_type: remote::model-context-protocol
        config: {}
    metadata_store:
      type: postgres
      host: ${env.POSTGRES_HOST:localhost}
      port: ${env.POSTGRES_PORT:5432}
      db: ${env.POSTGRES_DB:llamastack}
      user: ${env.POSTGRES_USER:llamastack}
      password: ${env.POSTGRES_PASSWORD:llamastack}
      table_name: llamastack_kvstore
    inference_store:
      type: postgres
      host: ${env.POSTGRES_HOST:localhost}
      port: ${env.POSTGRES_PORT:5432}
      db: ${env.POSTGRES_DB:llamastack}
      user: ${env.POSTGRES_USER:llamastack}
      password: ${env.POSTGRES_PASSWORD:llamastack}
    models:
    - metadata:
        embedding_dimension: 384
      model_id: all-MiniLM-L6-v2
      provider_id: sentence-transformers
      model_type: embedding
    - metadata: {}
      model_id: ${env.INFERENCE_MODEL}
      provider_id: vllm-inference
      model_type: llm
    - metadata: {}
      model_id: ${env.SAFETY_MODEL:meta-llama/Llama-Guard-3-1B}
      provider_id: vllm-safety
      model_type: llm
    shields:
    - shield_id: ${env.SAFETY_MODEL:meta-llama/Llama-Guard-3-1B}
    vector_dbs: []
    datasets: []
    scoring_fns: []
    benchmarks: []
    tool_groups:
    - toolgroup_id: builtin::websearch
      provider_id: tavily-search
    - toolgroup_id: builtin::rag
      provider_id: rag-runtime
    server:
      port: 8321
kind: ConfigMap
metadata:
  creationTimestamp: null
  name: llama-stack-config
docs/source/distributions/k8s/stack-k8s.yaml.template (new file, 69 lines)

apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: llama-pvc
spec:
  accessModes:
    - ReadWriteOnce
  resources:
    requests:
      storage: 1Gi
---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: llama-stack-server
spec:
  replicas: 1
  selector:
    matchLabels:
      app.kubernetes.io/name: llama-stack
      app.kubernetes.io/component: server
  template:
    metadata:
      labels:
        app.kubernetes.io/name: llama-stack
        app.kubernetes.io/component: server
    spec:
      containers:
        - name: llama-stack
          image: llamastack/distribution-remote-vllm:latest
          imagePullPolicy: Always # since we have specified latest instead of a version
          env:
            - name: ENABLE_CHROMADB
              value: "true"
            - name: CHROMADB_URL
              value: http://chromadb.default.svc.cluster.local:6000
            - name: VLLM_URL
              value: http://vllm-server.default.svc.cluster.local:8000/v1
            - name: VLLM_MAX_TOKENS
              value: "3072"
            - name: VLLM_SAFETY_URL
              value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
            - name: POSTGRES_HOST
              value: postgres-server.default.svc.cluster.local
            - name: POSTGRES_PORT
              value: "5432"
            - name: VLLM_TLS_VERIFY
              value: "false"
            - name: INFERENCE_MODEL
              value: "${INFERENCE_MODEL}"
            - name: SAFETY_MODEL
              value: "${SAFETY_MODEL}"
            - name: TAVILY_SEARCH_API_KEY
              value: "${TAVILY_SEARCH_API_KEY}"
          command: ["python", "-m", "llama_stack.distribution.server.server", "--config", "/etc/config/stack_run_config.yaml", "--port", "8321"]
          ports:
            - containerPort: 8321
          volumeMounts:
            - name: llama-storage
              mountPath: /root/.llama
            - name: llama-config
              mountPath: /etc/config
      volumes:
        - name: llama-storage
          persistentVolumeClaim:
            claimName: llama-pvc
        - name: llama-config
          configMap:
            name: llama-stack-config
docs/source/distributions/k8s/stack_run_config.yaml (new file, 121 lines)

version: '2'
image_name: kubernetes-demo
apis:
- agents
- inference
- safety
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: vllm-inference
    provider_type: remote::vllm
    config:
      url: ${env.VLLM_URL:http://localhost:8000/v1}
      max_tokens: ${env.VLLM_MAX_TOKENS:4096}
      api_token: ${env.VLLM_API_TOKEN:fake}
      tls_verify: ${env.VLLM_TLS_VERIFY:true}
  - provider_id: vllm-safety
    provider_type: remote::vllm
    config:
      url: ${env.VLLM_SAFETY_URL:http://localhost:8000/v1}
      max_tokens: ${env.VLLM_MAX_TOKENS:4096}
      api_token: ${env.VLLM_API_TOKEN:fake}
      tls_verify: ${env.VLLM_TLS_VERIFY:true}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: ${env.ENABLE_CHROMADB+chromadb}
    provider_type: remote::chromadb
    config:
      url: ${env.CHROMADB_URL:}
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config:
      excluded_categories: []
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: postgres
        host: ${env.POSTGRES_HOST:localhost}
        port: ${env.POSTGRES_PORT:5432}
        db: ${env.POSTGRES_DB:llamastack}
        user: ${env.POSTGRES_USER:llamastack}
        password: ${env.POSTGRES_PASSWORD:llamastack}
      responses_store:
        type: postgres
        host: ${env.POSTGRES_HOST:localhost}
        port: ${env.POSTGRES_PORT:5432}
        db: ${env.POSTGRES_DB:llamastack}
        user: ${env.POSTGRES_USER:llamastack}
        password: ${env.POSTGRES_PASSWORD:llamastack}
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:}
      sinks: ${env.TELEMETRY_SINKS:console}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: postgres
  host: ${env.POSTGRES_HOST:localhost}
  port: ${env.POSTGRES_PORT:5432}
  db: ${env.POSTGRES_DB:llamastack}
  user: ${env.POSTGRES_USER:llamastack}
  password: ${env.POSTGRES_PASSWORD:llamastack}
  table_name: llamastack_kvstore
inference_store:
  type: postgres
  host: ${env.POSTGRES_HOST:localhost}
  port: ${env.POSTGRES_PORT:5432}
  db: ${env.POSTGRES_DB:llamastack}
  user: ${env.POSTGRES_USER:llamastack}
  password: ${env.POSTGRES_PASSWORD:llamastack}
models:
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: sentence-transformers
  model_type: embedding
- metadata: {}
  model_id: ${env.INFERENCE_MODEL}
  provider_id: vllm-inference
  model_type: llm
- metadata: {}
  model_id: ${env.SAFETY_MODEL:meta-llama/Llama-Guard-3-1B}
  provider_id: vllm-safety
  model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:meta-llama/Llama-Guard-3-1B}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
server:
  port: 8321
docs/source/distributions/k8s/ui-k8s.yaml.template (new file, 62 lines)

apiVersion: apps/v1
kind: Deployment
metadata:
  name: llama-stack-ui
  labels:
    app.kubernetes.io/name: llama-stack
    app.kubernetes.io/component: ui
spec:
  replicas: 1
  selector:
    matchLabels:
      app.kubernetes.io/name: llama-stack
      app.kubernetes.io/component: ui
  template:
    metadata:
      labels:
        app.kubernetes.io/name: llama-stack
        app.kubernetes.io/component: ui
    spec:
      containers:
        - name: llama-stack-ui
          image: node:18-alpine
          command: ["/bin/sh"]
          env:
            - name: LLAMA_STACK_BACKEND_URL
              value: "http://llama-stack-service:8321"
            - name: LLAMA_STACK_UI_PORT
              value: "8322"
          args:
            - -c
            - |
              # Install git (not included in alpine by default)
              apk add --no-cache git

              # Clone the repository
              echo "Cloning repository..."
              git clone https://github.com/meta-llama/llama-stack.git /app

              # Navigate to the UI directory
              echo "Navigating to UI directory..."
              cd /app/llama_stack/ui

              # Check if package.json exists
              if [ ! -f "package.json" ]; then
                echo "ERROR: package.json not found in $(pwd)"
                ls -la
                exit 1
              fi

              # Install dependencies with verbose output
              echo "Installing dependencies..."
              npm install --verbose

              # Verify next is installed
              echo "Checking if next is installed..."
              npx next --version || echo "Next.js not found, checking node_modules..."
              ls -la node_modules/.bin/ | grep next || echo "No next binary found"

              npm run dev
          ports:
            - containerPort: 8322
          workingDir: /app
docs/source/distributions/k8s/vllm-k8s.yaml.template (new file, 71 lines)
@@ -0,0 +1,71 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: vllm-models
spec:
  accessModes:
    - ReadWriteOnce
  volumeMode: Filesystem
  resources:
    requests:
      storage: 50Gi
---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: vllm-server
spec:
  replicas: 1
  selector:
    matchLabels:
      app.kubernetes.io/name: vllm
  template:
    metadata:
      labels:
        app.kubernetes.io/name: vllm
        workload-type: inference
    spec:
      affinity:
        podAntiAffinity:
          requiredDuringSchedulingIgnoredDuringExecution:
          - labelSelector:
              matchExpressions:
              - key: workload-type
                operator: In
                values:
                - inference
            topologyKey: kubernetes.io/hostname  # Ensures no two inference pods on same node
      containers:
      - name: vllm
        image: vllm/vllm-openai:latest
        command: ["/bin/sh", "-c"]
        args:
          - "vllm serve ${INFERENCE_MODEL} --dtype float16 --enforce-eager --max-model-len 4096 --gpu-memory-utilization 0.6"
        env:
        - name: HUGGING_FACE_HUB_TOKEN
          valueFrom:
            secretKeyRef:
              name: hf-token-secret
              key: token
        ports:
        - containerPort: 8000
        volumeMounts:
        - name: llama-storage
          mountPath: /root/.cache/huggingface
      volumes:
      - name: llama-storage
        persistentVolumeClaim:
          claimName: vllm-models
---
apiVersion: v1
kind: Service
metadata:
  name: vllm-server
spec:
  selector:
    app.kubernetes.io/name: vllm
  ports:
  - protocol: TCP
    port: 8000
    targetPort: 8000
  type: ClusterIP
docs/source/distributions/k8s/vllm-safety-k8s.yaml.template (new file, 73 lines)
@@ -0,0 +1,73 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: vllm-models-safety
spec:
  accessModes:
    - ReadWriteOnce
  volumeMode: Filesystem
  storageClassName: gp2
  resources:
    requests:
      storage: 30Gi
---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: vllm-server-safety
spec:
  replicas: 1
  selector:
    matchLabels:
      app.kubernetes.io/name: vllm-safety
  template:
    metadata:
      labels:
        app.kubernetes.io/name: vllm-safety
        workload-type: inference
    spec:
      affinity:
        podAntiAffinity:
          requiredDuringSchedulingIgnoredDuringExecution:
          - labelSelector:
              matchExpressions:
              - key: workload-type
                operator: In
                values:
                - inference
            topologyKey: kubernetes.io/hostname  # Ensures no two inference pods on same node
      containers:
      - name: vllm-safety
        image: vllm/vllm-openai:latest
        command: ["/bin/sh", "-c"]
        args: [
          "vllm serve ${SAFETY_MODEL} --dtype float16 --enforce-eager --max-model-len 4096 --port 8001 --gpu-memory-utilization 0.3"
        ]
        env:
        - name: HUGGING_FACE_HUB_TOKEN
          valueFrom:
            secretKeyRef:
              name: hf-token-secret
              key: token
        ports:
        - containerPort: 8001
        volumeMounts:
        - name: llama-storage
          mountPath: /root/.cache/huggingface
      volumes:
      - name: llama-storage
        persistentVolumeClaim:
          claimName: vllm-models-safety
---
apiVersion: v1
kind: Service
metadata:
  name: vllm-server-safety
spec:
  selector:
    app.kubernetes.io/name: vllm-safety
  ports:
  - protocol: TCP
    port: 8001
    targetPort: 8001
  type: ClusterIP
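The templates above leave `${INFERENCE_MODEL}` and `${SAFETY_MODEL}` to be filled in before they are applied; how the distribution actually renders them is not shown in this diff, but an equivalent standard-library sketch would be:

```python
import os
from pathlib import Path

# Hypothetical helper: expand ${VAR} placeholders in a *.yaml.template so the
# result can be piped to `kubectl apply -f -`. The repo may use envsubst or a
# script of its own instead; the default model names here are only examples
# taken from elsewhere in this change.
def render_template(path: str) -> str:
    os.environ.setdefault("INFERENCE_MODEL", "meta-llama/Llama-3.2-3B-Instruct")
    os.environ.setdefault("SAFETY_MODEL", "meta-llama/Llama-Guard-3-1B")
    return os.path.expandvars(Path(path).read_text())


print(render_template("vllm-k8s.yaml.template"))
```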
@@ -18,6 +18,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
+| files | `inline::localfs` |
| inference | `remote::fireworks`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
@@ -82,7 +82,7 @@ for log in AgentEventLogger().log(response):
```
We will use `uv` to run the script
```
-uv run --with llama-stack-client demo_script.py
+uv run --with llama-stack-client,fire,requests demo_script.py
```
And you should see output like below.
```
@@ -103,6 +103,7 @@ getting_started/index
getting_started/detailed_tutorial
introduction/index
concepts/index
+openai/index
providers/index
distributions/index
building_applications/index
docs/source/openai/index.md (new file, 193 lines)
@@ -0,0 +1,193 @@
# OpenAI API Compatibility

## Server path

Llama Stack exposes an OpenAI-compatible API endpoint at `/v1/openai/v1`. So, for a Llama Stack server running locally on port `8321`, the full url to the OpenAI-compatible API endpoint is `http://localhost:8321/v1/openai/v1`.

## Clients

You should be able to use any client that speaks OpenAI APIs with Llama Stack. We regularly test with the official Llama Stack clients as well as OpenAI's official Python client.

### Llama Stack Client

When using the Llama Stack client, set the `base_url` to the root of your Llama Stack server. It will automatically route OpenAI-compatible requests to the right server endpoint for you.

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
```

### OpenAI Client

When using an OpenAI client, set the `base_url` to the `/v1/openai/v1` path on your Llama Stack server.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
```

Regardless of the client you choose, the following code examples should all work the same.

## APIs implemented

### Models

Many of the APIs require you to pass in a model parameter. To see the list of models available in your Llama Stack server:

```python
models = client.models.list()
```

### Responses

:::{note}
The Responses API implementation is still in active development. While it is quite usable, there are still unimplemented parts of the API. We'd love feedback on any use-cases you try that do not work to help prioritize the pieces left to implement. Please open issues in the [meta-llama/llama-stack](https://github.com/meta-llama/llama-stack) GitHub repository with details of anything that does not work.
:::

#### Simple inference

Request:

```
response = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input="Write a haiku about coding."
)

print(response.output_text)
```
Example output:

```text
Pixels dancing slow
Syntax whispers secrets sweet
Code's gentle silence
```

#### Structured Output

Request:

```python
response = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input=[
        {
            "role": "system",
            "content": "Extract the participants from the event information.",
        },
        {
            "role": "user",
            "content": "Alice and Bob are going to a science fair on Friday.",
        },
    ],
    text={
        "format": {
            "type": "json_schema",
            "name": "participants",
            "schema": {
                "type": "object",
                "properties": {
                    "participants": {"type": "array", "items": {"type": "string"}}
                },
                "required": ["participants"],
            },
        }
    },
)
print(response.output_text)
```

Example output:

```text
{ "participants": ["Alice", "Bob"] }
```

### Chat Completions

#### Simple inference

Request:

```python
chat_completion = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Write a haiku about coding."}],
)

print(chat_completion.choices[0].message.content)
```

Example output:

```text
Lines of code unfold
Logic flows like a river
Code's gentle beauty
```

#### Structured Output

Request:

```python
chat_completion = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[
        {
            "role": "system",
            "content": "Extract the participants from the event information.",
        },
        {
            "role": "user",
            "content": "Alice and Bob are going to a science fair on Friday.",
        },
    ],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "participants",
            "schema": {
                "type": "object",
                "properties": {
                    "participants": {"type": "array", "items": {"type": "string"}}
                },
                "required": ["participants"],
            },
        },
    },
)

print(chat_completion.choices[0].message.content)
```

Example output:

```text
{ "participants": ["Alice", "Bob"] }
```

### Completions

#### Simple inference

Request:

```python
completion = client.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct", prompt="Write a haiku about coding."
)

print(completion.choices[0].text)
```

Example output:

```text
Lines of code unfurl
Logic whispers in the dark
Art in hidden form
```
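The new page stops at non-streaming requests. As a usage sketch of the same endpoint with streaming (assuming your server and model support the standard OpenAI streaming path; the model name is simply the one used throughout the page):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# Stream a chat completion and print the text deltas as they arrive.
stream = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Write a haiku about coding."}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
print()
```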
@@ -37,6 +37,7 @@ from .openai_responses import (
    OpenAIResponseInputTool,
    OpenAIResponseObject,
    OpenAIResponseObjectStream,
+    OpenAIResponseText,
)

# TODO: use enum.StrEnum when we drop support for python 3.10

@@ -603,7 +604,9 @@ class Agents(Protocol):
        store: bool | None = True,
        stream: bool | None = False,
        temperature: float | None = None,
+        text: OpenAIResponseText | None = None,
        tools: list[OpenAIResponseInputTool] | None = None,
+        max_infer_iters: int | None = 10,  # this is an extension to the OpenAI API
    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
        """Create a new OpenAI response.

@@ -7,6 +7,7 @@
from typing import Annotated, Any, Literal

from pydantic import BaseModel, Field
+from typing_extensions import TypedDict

from llama_stack.schema_utils import json_schema_type, register_schema

@@ -126,6 +127,32 @@ OpenAIResponseOutput = Annotated[
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")


+# This has to be a TypedDict because we need a "schema" field and our strong
+# typing code in the schema generator doesn't support Pydantic aliases. That also
+# means we can't use a discriminator field here, because TypedDicts don't support
+# default values which the strong typing code requires for discriminators.
+class OpenAIResponseTextFormat(TypedDict, total=False):
+    """Configuration for Responses API text format.
+
+    :param type: Must be "text", "json_schema", or "json_object" to identify the format type
+    :param name: The name of the response format. Only used for json_schema.
+    :param schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model. Only used for json_schema.
+    :param description: (Optional) A description of the response format. Only used for json_schema.
+    :param strict: (Optional) Whether to strictly enforce the JSON schema. If true, the response must match the schema exactly. Only used for json_schema.
+    """
+
+    type: Literal["text"] | Literal["json_schema"] | Literal["json_object"]
+    name: str | None
+    schema: dict[str, Any] | None
+    description: str | None
+    strict: bool | None
+
+
+@json_schema_type
+class OpenAIResponseText(BaseModel):
+    format: OpenAIResponseTextFormat | None = None
+
+
@json_schema_type
class OpenAIResponseObject(BaseModel):
    created_at: int

@@ -138,6 +165,9 @@ class OpenAIResponseObject(BaseModel):
    previous_response_id: str | None = None
    status: str
    temperature: float | None = None
+    # Default to text format to avoid breaking the loading of old responses
+    # before the field was added. New responses will have this set always.
+    text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
    top_p: float | None = None
    truncation: str | None = None
    user: str | None = None

@@ -149,6 +179,30 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
    type: Literal["response.created"] = "response.created"


+@json_schema_type
+class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
+    response: OpenAIResponseObject
+    type: Literal["response.completed"] = "response.completed"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseOutputItemAdded(BaseModel):
+    response_id: str
+    item: OpenAIResponseOutput
+    output_index: int
+    sequence_number: int
+    type: Literal["response.output_item.added"] = "response.output_item.added"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseOutputItemDone(BaseModel):
+    response_id: str
+    item: OpenAIResponseOutput
+    output_index: int
+    sequence_number: int
+    type: Literal["response.output_item.done"] = "response.output_item.done"
+
+
@json_schema_type
class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
    content_index: int

@@ -160,14 +214,132 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):


@json_schema_type
-class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
-    response: OpenAIResponseObject
-    type: Literal["response.completed"] = "response.completed"
+class OpenAIResponseObjectStreamResponseOutputTextDone(BaseModel):
+    content_index: int
+    text: str  # final text of the output item
+    item_id: str
+    output_index: int
+    sequence_number: int
+    type: Literal["response.output_text.done"] = "response.output_text.done"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel):
+    delta: str
+    item_id: str
+    output_index: int
+    sequence_number: int
+    type: Literal["response.function_call_arguments.delta"] = "response.function_call_arguments.delta"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel):
+    arguments: str  # final arguments of the function call
+    item_id: str
+    output_index: int
+    sequence_number: int
+    type: Literal["response.function_call_arguments.done"] = "response.function_call_arguments.done"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseWebSearchCallInProgress(BaseModel):
+    item_id: str
+    output_index: int
+    sequence_number: int
+    type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseWebSearchCallSearching(BaseModel):
+    item_id: str
+    output_index: int
+    sequence_number: int
+    type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseWebSearchCallCompleted(BaseModel):
+    item_id: str
+    output_index: int
+    sequence_number: int
+    type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseMcpListToolsInProgress(BaseModel):
+    sequence_number: int
+    type: Literal["response.mcp_list_tools.in_progress"] = "response.mcp_list_tools.in_progress"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseMcpListToolsFailed(BaseModel):
+    sequence_number: int
+    type: Literal["response.mcp_list_tools.failed"] = "response.mcp_list_tools.failed"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseMcpListToolsCompleted(BaseModel):
+    sequence_number: int
+    type: Literal["response.mcp_list_tools.completed"] = "response.mcp_list_tools.completed"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(BaseModel):
+    delta: str
+    item_id: str
+    output_index: int
+    sequence_number: int
+    type: Literal["response.mcp_call.arguments.delta"] = "response.mcp_call.arguments.delta"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(BaseModel):
+    arguments: str  # final arguments of the MCP call
+    item_id: str
+    output_index: int
+    sequence_number: int
+    type: Literal["response.mcp_call.arguments.done"] = "response.mcp_call.arguments.done"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseMcpCallInProgress(BaseModel):
+    item_id: str
+    output_index: int
+    sequence_number: int
+    type: Literal["response.mcp_call.in_progress"] = "response.mcp_call.in_progress"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseMcpCallFailed(BaseModel):
+    sequence_number: int
+    type: Literal["response.mcp_call.failed"] = "response.mcp_call.failed"
+
+
+@json_schema_type
+class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
+    sequence_number: int
+    type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed"


OpenAIResponseObjectStream = Annotated[
    OpenAIResponseObjectStreamResponseCreated
+    | OpenAIResponseObjectStreamResponseOutputItemAdded
+    | OpenAIResponseObjectStreamResponseOutputItemDone
    | OpenAIResponseObjectStreamResponseOutputTextDelta
+    | OpenAIResponseObjectStreamResponseOutputTextDone
+    | OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta
+    | OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone
+    | OpenAIResponseObjectStreamResponseWebSearchCallInProgress
+    | OpenAIResponseObjectStreamResponseWebSearchCallSearching
+    | OpenAIResponseObjectStreamResponseWebSearchCallCompleted
+    | OpenAIResponseObjectStreamResponseMcpListToolsInProgress
+    | OpenAIResponseObjectStreamResponseMcpListToolsFailed
+    | OpenAIResponseObjectStreamResponseMcpListToolsCompleted
+    | OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta
+    | OpenAIResponseObjectStreamResponseMcpCallArgumentsDone
+    | OpenAIResponseObjectStreamResponseMcpCallInProgress
+    | OpenAIResponseObjectStreamResponseMcpCallFailed
+    | OpenAIResponseObjectStreamResponseMcpCallCompleted
    | OpenAIResponseObjectStreamResponseCompleted,
    Field(discriminator="type"),
]
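Because `OpenAIResponseObjectStream` is discriminated on `type`, a consumer can branch on that field. A minimal sketch of draining such a stream (it assumes the pre-existing delta event uses the matching `response.output_text.delta` type and carries a `delta` string, as in the OpenAI event naming above):

```python
from collections.abc import AsyncIterator


async def print_response_stream(events: AsyncIterator) -> None:
    # Print text deltas as they arrive and stop at the terminal completed event.
    async for event in events:
        if event.type == "response.output_text.delta":
            print(event.delta, end="", flush=True)
        elif event.type == "response.completed":
            print()
            break
```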
@@ -4,179 +4,158 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from typing import Protocol, runtime_checkable
-
-from pydantic import BaseModel
-
-from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
-from llama_stack.schema_utils import json_schema_type, webmethod
-
-
-@json_schema_type
-class FileUploadResponse(BaseModel):
-    """
-    Response after initiating a file upload session.
-
-    :param id: ID of the upload session
-    :param url: Upload URL for the file or file parts
-    :param offset: Upload content offset
-    :param size: Upload content size
-    """
-
-    id: str
-    url: str
-    offset: int
-    size: int
-
-
-@json_schema_type
-class BucketResponse(BaseModel):
-    name: str
-
-
-@json_schema_type
-class ListBucketResponse(BaseModel):
-    """
-    Response representing a list of file entries.
-
-    :param data: List of FileResponse entries
-    """
-
-    data: list[BucketResponse]
-
-
-@json_schema_type
-class FileResponse(BaseModel):
-    """
-    Response representing a file entry.
-
-    :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
-    :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
-    :param mime_type: MIME type of the file
-    :param url: Upload URL for the file contents
-    :param bytes: Size of the file in bytes
-    :param created_at: Timestamp of when the file was created
-    """
-
-    bucket: str
-    key: str
-    mime_type: str
-    url: str
-    bytes: int
-    created_at: int
-
-
-@json_schema_type
-class ListFileResponse(BaseModel):
-    """
-    Response representing a list of file entries.
-
-    :param data: List of FileResponse entries
-    """
-
-    data: list[FileResponse]
-
-
-@runtime_checkable
-@trace_protocol
-class Files(Protocol):
-    @webmethod(route="/files", method="POST")
-    async def create_upload_session(
-        self,
-        bucket: str,
-        key: str,
-        mime_type: str,
-        size: int,
-    ) -> FileUploadResponse:
-        """
-        Create a new upload session for a file identified by a bucket and key.
-
-        :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
-        :param mime_type: MIME type of the file.
-        :param size: File size in bytes.
-        :returns: A FileUploadResponse.
-        """
-        ...
-
-    @webmethod(route="/files/session:{upload_id}", method="POST", raw_bytes_request_body=True)
-    async def upload_content_to_session(
-        self,
-        upload_id: str,
-    ) -> FileResponse | None:
-        """
-        Upload file content to an existing upload session.
-        On the server, request body will have the raw bytes that are uploaded.
-
-        :param upload_id: ID of the upload session.
-        :returns: A FileResponse or None if the upload is not complete.
-        """
-        ...
-
-    @webmethod(route="/files/session:{upload_id}", method="GET")
-    async def get_upload_session_info(
-        self,
-        upload_id: str,
-    ) -> FileUploadResponse:
-        """
-        Returns information about an existsing upload session.
-
-        :param upload_id: ID of the upload session.
-        :returns: A FileUploadResponse.
-        """
-        ...
-
-    @webmethod(route="/files", method="GET")
-    async def list_all_buckets(
-        self,
-        bucket: str,
-    ) -> ListBucketResponse:
-        """
-        List all buckets.
-
-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
-        :returns: A ListBucketResponse.
-        """
-        ...
-
-    @webmethod(route="/files/{bucket}", method="GET")
-    async def list_files_in_bucket(
-        self,
-        bucket: str,
-    ) -> ListFileResponse:
-        """
-        List all files in a bucket.
-
-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
-        :returns: A ListFileResponse.
-        """
-        ...
-
-    @webmethod(route="/files/{bucket}/{key:path}", method="GET")
-    async def get_file(
-        self,
-        bucket: str,
-        key: str,
-    ) -> FileResponse:
-        """
-        Get a file info identified by a bucket and key.
-
-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
-        :returns: A FileResponse.
-        """
-        ...
-
-    @webmethod(route="/files/{bucket}/{key:path}", method="DELETE")
-    async def delete_file(
-        self,
-        bucket: str,
-        key: str,
-    ) -> None:
-        """
-        Delete a file identified by a bucket and key.
-
-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
-        """
-        ...
+from enum import Enum
+from typing import Annotated, Literal, Protocol, runtime_checkable
+
+from fastapi import File, Form, Response, UploadFile
+from pydantic import BaseModel
+
+from llama_stack.apis.common.responses import Order
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+from llama_stack.schema_utils import json_schema_type, webmethod
+
+
+# OpenAI Files API Models
+class OpenAIFilePurpose(str, Enum):
+    """
+    Valid purpose values for OpenAI Files API.
+    """
+
+    ASSISTANTS = "assistants"
+    # TODO: Add other purposes as needed
+
+
+@json_schema_type
+class OpenAIFileObject(BaseModel):
+    """
+    OpenAI File object as defined in the OpenAI Files API.
+
+    :param object: The object type, which is always "file"
+    :param id: The file identifier, which can be referenced in the API endpoints
+    :param bytes: The size of the file, in bytes
+    :param created_at: The Unix timestamp (in seconds) for when the file was created
+    :param expires_at: The Unix timestamp (in seconds) for when the file expires
+    :param filename: The name of the file
+    :param purpose: The intended purpose of the file
+    """
+
+    object: Literal["file"] = "file"
+    id: str
+    bytes: int
+    created_at: int
+    expires_at: int
+    filename: str
+    purpose: OpenAIFilePurpose
+
+
+@json_schema_type
+class ListOpenAIFileResponse(BaseModel):
+    """
+    Response for listing files in OpenAI Files API.
+
+    :param data: List of file objects
+    :param object: The object type, which is always "list"
+    """
+
+    data: list[OpenAIFileObject]
+    has_more: bool
+    first_id: str
+    last_id: str
+    object: Literal["list"] = "list"
+
+
+@json_schema_type
+class OpenAIFileDeleteResponse(BaseModel):
+    """
+    Response for deleting a file in OpenAI Files API.
+
+    :param id: The file identifier that was deleted
+    :param object: The object type, which is always "file"
+    :param deleted: Whether the file was successfully deleted
+    """
+
+    id: str
+    object: Literal["file"] = "file"
+    deleted: bool
+
+
+@runtime_checkable
+@trace_protocol
+class Files(Protocol):
+    # OpenAI Files API Endpoints
+    @webmethod(route="/openai/v1/files", method="POST")
+    async def openai_upload_file(
+        self,
+        file: Annotated[UploadFile, File()],
+        purpose: Annotated[OpenAIFilePurpose, Form()],
+    ) -> OpenAIFileObject:
+        """
+        Upload a file that can be used across various endpoints.
+
+        The file upload should be a multipart form request with:
+        - file: The File object (not file name) to be uploaded.
+        - purpose: The intended purpose of the uploaded file.
+
+        :param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
+        :param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").
+        :returns: An OpenAIFileObject representing the uploaded file.
+        """
+        ...
+
+    @webmethod(route="/openai/v1/files", method="GET")
+    async def openai_list_files(
+        self,
+        after: str | None = None,
+        limit: int | None = 10000,
+        order: Order | None = Order.desc,
+        purpose: OpenAIFilePurpose | None = None,
+    ) -> ListOpenAIFileResponse:
+        """
+        Returns a list of files that belong to the user's organization.
+
+        :param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
+        :param limit: A limit on the number of objects to be returned. Limit can range between 1 and 10,000, and the default is 10,000.
+        :param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
+        :param purpose: Only return files with the given purpose.
+        :returns: A ListOpenAIFileResponse containing the list of files.
+        """
+        ...
+
+    @webmethod(route="/openai/v1/files/{file_id}", method="GET")
+    async def openai_retrieve_file(
+        self,
+        file_id: str,
+    ) -> OpenAIFileObject:
+        """
+        Returns information about a specific file.
+
+        :param file_id: The ID of the file to use for this request.
+        :returns: An OpenAIFileObject containing file information.
+        """
+        ...
+
+    @webmethod(route="/openai/v1/files/{file_id}", method="DELETE")
+    async def openai_delete_file(
+        self,
+        file_id: str,
+    ) -> OpenAIFileDeleteResponse:
+        """
+        Delete a file.
+
+        :param file_id: The ID of the file to use for this request.
+        :returns: An OpenAIFileDeleteResponse indicating successful deletion.
+        """
+        ...
+
+    @webmethod(route="/openai/v1/files/{file_id}/content", method="GET")
+    async def openai_retrieve_file_content(
+        self,
+        file_id: str,
+    ) -> Response:
+        """
+        Returns the contents of the specified file.
+
+        :param file_id: The ID of the file to use for this request.
+        :returns: The raw file content as a binary response.
+        """
+        ...
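Because the replacement routes follow the OpenAI Files API, the standard OpenAI client can exercise them once a files provider (such as the `inline::localfs` entry added above) is configured. A usage sketch; the local file path is hypothetical:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# Upload, list, fetch metadata and content, then delete.
with open("notes.txt", "rb") as f:  # hypothetical local file
    uploaded = client.files.create(file=f, purpose="assistants")

print(client.files.list().data)
print(client.files.retrieve(uploaded.id).filename)
print(client.files.content(uploaded.id).read())
client.files.delete(uploaded.id)
```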
@@ -35,7 +35,8 @@ class StackRun(Subcommand):
            "config",
            type=str,
            nargs="?",  # Make it optional
-            help="Path to config file to use for the run. Required for venv and conda environments.",
+            metavar="config | template",
+            help="Path to config file to use for the run or name of known template (`llama stack list` for a list).",
        )
        self.parser.add_argument(
            "--port",

@@ -59,7 +60,7 @@ class StackRun(Subcommand):
            "--image-type",
            type=str,
            help="Image Type used during the build. This can be either conda or container or venv.",
-            choices=[e.value for e in ImageType],
+            choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value],
        )
        self.parser.add_argument(
            "--enable-ui",

@@ -154,6 +155,9 @@ class StackRun(Subcommand):
            # func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
            if callable(getattr(args, arg)):
                continue
+            if arg == "config" and template_name:
+                server_args.config = str(config_file)
+            else:
                setattr(server_args, arg, getattr(args, arg))

        # Run the server

@@ -1,86 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any
-
-from llama_stack.distribution.datatypes import AccessAttributes
-from llama_stack.log import get_logger
-
-logger = get_logger(__name__, category="core")
-
-
-def check_access(
-    obj_identifier: str,
-    obj_attributes: AccessAttributes | None,
-    user_attributes: dict[str, Any] | None = None,
-) -> bool:
-    """Check if the current user has access to the given object, based on access attributes.
-
-    Access control algorithm:
-    1. If the resource has no access_attributes, access is GRANTED to all authenticated users
-    2. If the user has no attributes, access is DENIED to any object with access_attributes defined
-    3. For each attribute category in the resource's access_attributes:
-       a. If the user lacks that category, access is DENIED
-       b. If the user has the category but none of the required values, access is DENIED
-       c. If the user has at least one matching value in each required category, access is GRANTED
-
-    Example:
-        # Resource requires:
-        access_attributes = AccessAttributes(
-            roles=["admin", "data-scientist"],
-            teams=["ml-team"]
-        )
-
-        # User has:
-        user_attributes = {
-            "roles": ["data-scientist", "engineer"],
-            "teams": ["ml-team", "infra-team"],
-            "projects": ["llama-3"]
-        }
-
-        # Result: Access GRANTED
-        # - User has the "data-scientist" role (matches one of the required roles)
-        # - AND user is part of the "ml-team" (matches the required team)
-        # - The extra "projects" attribute is ignored
-
-    Args:
-        obj_identifier: The identifier of the resource object to check access for
-        obj_attributes: The access attributes of the resource object
-        user_attributes: The attributes of the current user
-
-    Returns:
-        bool: True if access is granted, False if denied
-    """
-    # If object has no access attributes, allow access by default
-    if not obj_attributes:
-        return True
-
-    # If no user attributes, deny access to objects with access control
-    if not user_attributes:
-        return False
-
-    dict_attribs = obj_attributes.model_dump(exclude_none=True)
-    if not dict_attribs:
-        return True
-
-    # Check each attribute category (requires ALL categories to match)
-    # TODO: formalize this into a proper ABAC policy
-    for attr_key, required_values in dict_attribs.items():
-        user_values = user_attributes.get(attr_key, [])
-
-        if not user_values:
-            logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
-            return False
-
-        if not any(val in user_values for val in required_values):
-            logger.debug(
-                f"Access denied to {obj_identifier}: "
-                f"no match for attribute '{attr_key}', required one of {required_values}"
-            )
-            return False
-
-    logger.debug(f"Access granted to {obj_identifier}")
-    return True

@@ -3,5 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-
-from .verification import get_distribution_template  # noqa: F401
109
llama_stack/distribution/access_control/access_control.py
Normal file
109
llama_stack/distribution/access_control/access_control.py
Normal file
|
|
@ -0,0 +1,109 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import User
|
||||||
|
|
||||||
|
from .conditions import (
|
||||||
|
Condition,
|
||||||
|
ProtectedResource,
|
||||||
|
parse_conditions,
|
||||||
|
)
|
||||||
|
from .datatypes import (
|
||||||
|
AccessRule,
|
||||||
|
Action,
|
||||||
|
Scope,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def matches_resource(resource_scope: str, actual_resource: str) -> bool:
|
||||||
|
if resource_scope == actual_resource:
|
||||||
|
return True
|
||||||
|
return resource_scope.endswith("::*") and actual_resource.startswith(resource_scope[:-1])
|
||||||
|
|
||||||
|
|
||||||
|
def matches_scope(
|
||||||
|
scope: Scope,
|
||||||
|
action: Action,
|
||||||
|
resource: str,
|
||||||
|
user: str | None,
|
||||||
|
) -> bool:
|
||||||
|
if scope.resource and not matches_resource(scope.resource, resource):
|
||||||
|
return False
|
||||||
|
if scope.principal and scope.principal != user:
|
||||||
|
return False
|
||||||
|
return action in scope.actions
|
||||||
|
|
||||||
|
|
||||||
|
def as_list(obj: Any) -> list[Any]:
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return obj
|
||||||
|
return [obj]
|
||||||
|
|
||||||
|
|
||||||
|
def matches_conditions(
|
||||||
|
conditions: list[Condition],
|
||||||
|
resource: ProtectedResource,
|
||||||
|
user: User,
|
||||||
|
) -> bool:
|
||||||
|
for condition in conditions:
|
||||||
|
# must match all conditions
|
||||||
|
if not condition.matches(resource, user):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def default_policy() -> list[AccessRule]:
|
||||||
|
# for backwards compatibility, if no rules are provided, assume
|
||||||
|
# full access subject to previous attribute matching rules
|
||||||
|
return [
|
||||||
|
AccessRule(
|
||||||
|
permit=Scope(actions=list(Action)),
|
||||||
|
when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def is_action_allowed(
|
||||||
|
policy: list[AccessRule],
|
||||||
|
action: Action,
|
||||||
|
resource: ProtectedResource,
|
||||||
|
user: User | None,
|
||||||
|
) -> bool:
|
||||||
|
# If user is not set, assume authentication is not enabled
|
||||||
|
if not user:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not len(policy):
|
||||||
|
policy = default_policy()
|
||||||
|
|
||||||
|
qualified_resource_id = resource.type + "::" + resource.identifier
|
||||||
|
for rule in policy:
|
||||||
|
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
|
||||||
|
if rule.when:
|
||||||
|
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||||
|
return False
|
||||||
|
elif rule.unless:
|
||||||
|
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
elif rule.permit and matches_scope(rule.permit, action, qualified_resource_id, user.principal):
|
||||||
|
if rule.when:
|
||||||
|
if matches_conditions(parse_conditions(as_list(rule.when)), resource, user):
|
||||||
|
return True
|
||||||
|
elif rule.unless:
|
||||||
|
if not matches_conditions(parse_conditions(as_list(rule.unless)), resource, user):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
# assume access is denied unless we find a rule that permits access
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class AccessDeniedError(RuntimeError):
|
||||||
|
pass
|
||||||
129
llama_stack/distribution/access_control/conditions.py
Normal file
129
llama_stack/distribution/access_control/conditions.py
Normal file
|
|
@ -0,0 +1,129 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class User(Protocol):
|
||||||
|
principal: str
|
||||||
|
attributes: dict[str, list[str]] | None
|
||||||
|
|
||||||
|
|
||||||
|
class ProtectedResource(Protocol):
|
||||||
|
type: str
|
||||||
|
identifier: str
|
||||||
|
owner: User
|
||||||
|
|
||||||
|
|
||||||
|
class Condition(Protocol):
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class UserInOwnersList:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def owners_values(self, resource: ProtectedResource) -> list[str] | None:
|
||||||
|
if (
|
||||||
|
hasattr(resource, "owner")
|
||||||
|
and resource.owner
|
||||||
|
and resource.owner.attributes
|
||||||
|
and self.name in resource.owner.attributes
|
||||||
|
):
|
||||||
|
return resource.owner.attributes[self.name]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
required = self.owners_values(resource)
|
||||||
|
if not required:
|
||||||
|
return True
|
||||||
|
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
|
||||||
|
return False
|
||||||
|
user_values = user.attributes[self.name]
|
||||||
|
for value in required:
|
||||||
|
if value in user_values:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user in owners {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserNotInOwnersList(UserInOwnersList):
|
||||||
|
def __init__(self, name: str):
|
||||||
|
super().__init__(name)
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return not super().matches(resource, user)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user not in owners {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserWithValueInList:
|
||||||
|
def __init__(self, name: str, value: str):
|
||||||
|
self.name = name
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
if user.attributes and self.name in user.attributes:
|
||||||
|
return self.value in user.attributes[self.name]
|
||||||
|
print(f"User does not have {self.value} in {self.name}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user with {self.value} in {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserWithValueNotInList(UserWithValueInList):
|
||||||
|
def __init__(self, name: str, value: str):
|
||||||
|
super().__init__(name, value)
|
||||||
|
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return not super().matches(resource, user)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"user with {self.value} not in {self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserIsOwner:
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return resource.owner.principal == user.principal if resource.owner else False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "user is owner"
|
||||||
|
|
||||||
|
|
||||||
|
class UserIsNotOwner:
|
||||||
|
def matches(self, resource: ProtectedResource, user: User) -> bool:
|
||||||
|
return not resource.owner or resource.owner.principal != user.principal
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "user is not owner"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_condition(condition: str) -> Condition:
|
||||||
|
words = condition.split()
|
||||||
|
match words:
|
||||||
|
case ["user", "is", "owner"]:
|
||||||
|
return UserIsOwner()
|
||||||
|
case ["user", "is", "not", "owner"]:
|
||||||
|
return UserIsNotOwner()
|
||||||
|
case ["user", "with", value, "in", name]:
|
||||||
|
return UserWithValueInList(name, value)
|
||||||
|
case ["user", "with", value, "not", "in", name]:
|
||||||
|
return UserWithValueNotInList(name, value)
|
||||||
|
case ["user", "in", "owners", name]:
|
||||||
|
return UserInOwnersList(name)
|
||||||
|
case ["user", "not", "in", "owners", name]:
|
||||||
|
return UserNotInOwnersList(name)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid condition: {condition}")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_conditions(conditions: list[str]) -> list[Condition]:
|
||||||
|
return [parse_condition(c) for c in conditions]
|
||||||
107
llama_stack/distribution/access_control/datatypes.py
Normal file
107
llama_stack/distribution/access_control/datatypes.py
Normal file
|
|
@ -0,0 +1,107 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from .conditions import parse_conditions


class Action(str, Enum):
    CREATE = "create"
    READ = "read"
    UPDATE = "update"
    DELETE = "delete"


class Scope(BaseModel):
    principal: str | None = None
    actions: Action | list[Action]
    resource: str | None = None


def _mutually_exclusive(obj, a: str, b: str):
    if getattr(obj, a) and getattr(obj, b):
        raise ValueError(f"{a} and {b} are mutually exclusive")


def _require_one_of(obj, a: str, b: str):
    if not getattr(obj, a) and not getattr(obj, b):
        raise ValueError(f"one of {a} or {b} is required")


class AccessRule(BaseModel):
    """Access rule based loosely on the cedar policy language

    A rule defines a list of actions either to permit or to forbid. It may specify a
    principal or a resource that must match for the rule to take effect. The resource
    to match should be specified in the form of a type-qualified identifier, e.g.
    model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
    e.g. model::*. If the principal or resource are not specified, they will match all
    requests.

    A rule may also specify a condition, either a 'when' or an 'unless', with additional
    constraints as to where the rule applies. The constraints supported at present are:

    - 'user with <attr-value> in <attr-name>'
    - 'user with <attr-value> not in <attr-name>'
    - 'user is owner'
    - 'user is not owner'
    - 'user in owners <attr-name>'
    - 'user not in owners <attr-name>'

    Rules are tested in order to find a match. If a match is found, the request is
    permitted or forbidden depending on the type of rule. If no match is found, the
    request is denied. If no rules are specified, a rule that allows any action as
    long as the resource attributes match the user attributes is added
    (i.e. the previous behaviour is the default).

    Some examples in yaml:

    - permit:
        principal: user-1
        actions: [create, read, delete]
        resource: model::*
      description: user-1 has full access to all models
    - permit:
        principal: user-2
        actions: [read]
        resource: model::model-1
      description: user-2 has read access to model-1 only
    - permit:
        actions: [read]
      when: user in owner teams
      description: any user has read access to any resource created by a member of their team
    - forbid:
        actions: [create, read, delete]
        resource: vector_db::*
      unless: user with admin in roles
      description: only user with admin role can use vector_db resources
    """

    permit: Scope | None = None
    forbid: Scope | None = None
    when: str | list[str] | None = None
    unless: str | list[str] | None = None
    description: str | None = None

    @model_validator(mode="after")
    def validate_rule_format(self) -> Self:
        _require_one_of(self, "permit", "forbid")
        _mutually_exclusive(self, "permit", "forbid")
        _mutually_exclusive(self, "when", "unless")
        if isinstance(self.when, list):
            parse_conditions(self.when)
        elif self.when:
            parse_conditions([self.when])
        if isinstance(self.unless, list):
            parse_conditions(self.unless)
        elif self.unless:
            parse_conditions([self.unless])
        return self
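For illustration only (not part of the commit), a minimal sketch of how the validator above behaves; every name used here is defined in the file just shown:

# A rule with a single condition validates cleanly.
rule = AccessRule(
    permit=Scope(actions=[Action.READ], resource="model::*"),
    when="user with admin in roles",
    description="admins may read any model",
)

# permit and forbid are mutually exclusive, so this raises a validation error.
try:
    AccessRule(permit=Scope(actions=[Action.READ]), forbid=Scope(actions=[Action.DELETE]))
except ValueError as e:
    print(e)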
@@ -29,6 +29,8 @@ SERVER_DEPENDENCIES = [
     "fire",
     "httpx",
     "uvicorn",
+    "opentelemetry-sdk",
+    "opentelemetry-exporter-otlp-proto-http",
 ]
@@ -24,6 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput
 from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.distribution.access_control.datatypes import AccessRule
 from llama_stack.providers.datatypes import Api, ProviderSpec
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
@@ -35,126 +36,66 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
 RoutingKey = str | list[str]


-class AccessAttributes(BaseModel):
-    """Structured representation of user attributes for access control.
-
-    This model defines a structured approach to representing user attributes
-    with common standard categories for access control.
-
-    Standard attribute categories include:
-    - roles: Role-based attributes (e.g., admin, data-scientist)
-    - teams: Team-based attributes (e.g., ml-team, infra-team)
-    - projects: Project access attributes (e.g., llama-3, customer-insights)
-    - namespaces: Namespace-based access control for resource isolation
-    """
-
-    # Standard attribute categories - the minimal set we need now
-    roles: list[str] | None = Field(
-        default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
-    )
-
-    teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
-
-    projects: list[str] | None = Field(
-        default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
-    )
-
-    namespaces: list[str] | None = Field(
-        default=None, description="Namespace-based access control for resource isolation"
-    )
-
-
-class ResourceWithACL(Resource):
-    """Extension of Resource that adds attribute-based access control capabilities.
-
-    This class adds an optional access_attributes field that allows fine-grained control
-    over which users can access each resource. When attributes are defined, a user must have
-    matching attributes to access the resource.
-
-    Attribute Matching Algorithm:
-    1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
-    2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
-    3. The matching algorithm requires ALL categories to match (AND relationship between categories)
-    4. Within each category, ANY value match is sufficient (OR relationship within a category)
-
-    Examples:
-        # Resource visible to everyone (no access control)
-        model = Model(identifier="llama-2", ...)
-
-        # Resource visible only to admins
-        model = Model(
-            identifier="gpt-4",
-            access_attributes=AccessAttributes(roles=["admin"])
-        )
-
-        # Resource visible to data scientists on the ML team
-        model = Model(
-            identifier="private-model",
-            access_attributes=AccessAttributes(
-                roles=["data-scientist", "researcher"],
-                teams=["ml-team"]
-            )
-        )
-        # ^ User must have at least one of the roles AND be on the ml-team
-
-        # Resource visible to users with specific project access
-        vector_db = VectorDB(
-            identifier="customer-embeddings",
-            access_attributes=AccessAttributes(
-                projects=["customer-insights"],
-                namespaces=["confidential"]
-            )
-        )
-        # ^ User must have access to the customer-insights project AND have confidential namespace
-    """
-
-    access_attributes: AccessAttributes | None = None
+class User(BaseModel):
+    principal: str
+    # further attributes that may be used for access control decisions
+    attributes: dict[str, list[str]] | None = None
+
+    def __init__(self, principal: str, attributes: dict[str, list[str]] | None):
+        super().__init__(principal=principal, attributes=attributes)
+
+
+class ResourceWithOwner(Resource):
+    """Extension of Resource that adds an optional owner, i.e. the user that created the
+    resource. This can be used to constrain access to the resource."""
+
+    owner: User | None = None


 # Use the extended Resource for all routable objects
-class ModelWithACL(Model, ResourceWithACL):
+class ModelWithOwner(Model, ResourceWithOwner):
     pass


-class ShieldWithACL(Shield, ResourceWithACL):
+class ShieldWithOwner(Shield, ResourceWithOwner):
     pass


-class VectorDBWithACL(VectorDB, ResourceWithACL):
+class VectorDBWithOwner(VectorDB, ResourceWithOwner):
     pass


-class DatasetWithACL(Dataset, ResourceWithACL):
+class DatasetWithOwner(Dataset, ResourceWithOwner):
     pass


-class ScoringFnWithACL(ScoringFn, ResourceWithACL):
+class ScoringFnWithOwner(ScoringFn, ResourceWithOwner):
     pass


-class BenchmarkWithACL(Benchmark, ResourceWithACL):
+class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
     pass


-class ToolWithACL(Tool, ResourceWithACL):
+class ToolWithOwner(Tool, ResourceWithOwner):
     pass


-class ToolGroupWithACL(ToolGroup, ResourceWithACL):
+class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
     pass


 RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup

 RoutableObjectWithProvider = Annotated[
-    ModelWithACL
-    | ShieldWithACL
-    | VectorDBWithACL
-    | DatasetWithACL
-    | ScoringFnWithACL
-    | BenchmarkWithACL
-    | ToolWithACL
-    | ToolGroupWithACL,
+    ModelWithOwner
+    | ShieldWithOwner
+    | VectorDBWithOwner
+    | DatasetWithOwner
+    | ScoringFnWithOwner
+    | BenchmarkWithOwner
+    | ToolWithOwner
+    | ToolGroupWithOwner,
     Field(discriminator="type"),
 ]
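A hedged sketch of the new owner model (illustrative, not from the commit; the provider identifiers below are made up):

user = User("user-1", {"roles": ["admin"], "teams": ["ml-team"]})

model = ModelWithOwner(
    identifier="my-model",
    provider_id="example-provider",        # hypothetical provider id
    provider_resource_id="my-model",
    owner=user,
)
assert model.owner is not None and model.owner.principal == "user-1"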
@@ -234,6 +175,7 @@ class AuthenticationConfig(BaseModel):
         ...,
         description="Provider-specific configuration",
     )
+    access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")


 class AuthenticationRequiredError(Exception):
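As a sketch of how an access_policy block could be declared and parsed (illustrative; only the rule shape is defined by AccessRule above, the surrounding configuration keys are an assumption):

import yaml

policy_yaml = """
access_policy:
  - permit:
      actions: [read]
    description: all authenticated users can read all resources
  - permit:
      actions: [create, update, delete]
    when: user is owner
    description: users may modify only resources they created
"""

rules = [AccessRule(**rule) for rule in yaml.safe_load(policy_yaml)["access_policy"]]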
@@ -10,6 +10,8 @@ import logging
 from contextlib import AbstractContextManager
 from typing import Any

+from llama_stack.distribution.datatypes import User
+
 from .utils.dynamic import instantiate_class_type

 log = logging.getLogger(__name__)

@@ -21,12 +23,10 @@ PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
 class RequestProviderDataContext(AbstractContextManager):
     """Context manager for request provider data"""

-    def __init__(
-        self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
-    ):
+    def __init__(self, provider_data: dict[str, Any] | None = None, user: User | None = None):
         self.provider_data = provider_data or {}
-        if auth_attributes:
-            self.provider_data["__auth_attributes"] = auth_attributes
+        if user:
+            self.provider_data["__authenticated_user"] = user

         self.token = None

@@ -95,9 +95,9 @@ def request_provider_data_context(
     return RequestProviderDataContext(provider_data, auth_attributes)


-def get_auth_attributes() -> dict[str, list[str]] | None:
+def get_authenticated_user() -> User | None:
     """Helper to retrieve auth attributes from the provider data context"""
     provider_data = PROVIDER_DATA_VAR.get()
     if not provider_data:
         return None
-    return provider_data.get("__auth_attributes")
+    return provider_data.get("__authenticated_user")
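For illustration only (not from the commit): inside the context manager, downstream code can recover the authenticated user; no provider-data headers are assumed here:

user = User("user-1", {"roles": ["admin"]})

with request_provider_data_context({}, user):
    current = get_authenticated_user()
    assert current is not None and current.principal == "user-1"

# Outside the context there is no authenticated user.
assert get_authenticated_user() is None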
@@ -28,6 +28,7 @@ from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.distribution.client import get_client_impl
 from llama_stack.distribution.datatypes import (
+    AccessRule,
     AutoRoutedProviderSpec,
     Provider,
     RoutingTableProviderSpec,

@@ -118,6 +119,7 @@ async def resolve_impls(
     run_config: StackRunConfig,
     provider_registry: ProviderRegistry,
     dist_registry: DistributionRegistry,
+    policy: list[AccessRule],
 ) -> dict[Api, Any]:
     """
     Resolves provider implementations by:

@@ -140,7 +142,7 @@ async def resolve_impls(

     sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)

-    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
+    return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config, policy)


 def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:

@@ -247,6 +249,7 @@ async def instantiate_providers(
     router_apis: set[Api],
     dist_registry: DistributionRegistry,
     run_config: StackRunConfig,
+    policy: list[AccessRule],
 ) -> dict:
     """Instantiates providers asynchronously while managing dependencies."""
     impls: dict[Api, Any] = {}

@@ -261,7 +264,7 @@ async def instantiate_providers(
         if isinstance(provider.spec, RoutingTableProviderSpec):
             inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

-        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
+        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config, policy)

         if api_str.startswith("inner-"):
             inner_impls_by_provider_id[api_str][provider.provider_id] = impl

@@ -312,6 +315,7 @@ async def instantiate_provider(
     inner_impls: dict[str, Any],
     dist_registry: DistributionRegistry,
     run_config: StackRunConfig,
+    policy: list[AccessRule],
 ):
     provider_spec = provider.spec
     if not hasattr(provider_spec, "module"):

@@ -336,13 +340,15 @@ async def instantiate_provider(
         method = "get_routing_table_impl"

         config = None
-        args = [provider_spec.api, inner_impls, deps, dist_registry]
+        args = [provider_spec.api, inner_impls, deps, dist_registry, policy]
     else:
         method = "get_provider_impl"

         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider.config)
         args = [config, deps]
+        if "policy" in inspect.signature(getattr(module, method)).parameters:
+            args.append(policy)

     fn = getattr(module, method)
     impl = await fn(*args)
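A hedged sketch of what the signature check above enables (the provider class below is invented): a provider whose get_provider_impl declares a policy parameter receives the access policy, while existing two-argument providers are called exactly as before.

# Hypothetical provider module
async def get_provider_impl(config, deps, policy: list[AccessRule]):
    impl = ExampleProviderImpl(config, deps=deps, policy=policy)  # ExampleProviderImpl is made up
    await impl.initialize()
    return impl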
@@ -6,7 +6,7 @@

 from typing import Any

-from llama_stack.distribution.datatypes import RoutedProtocol
+from llama_stack.distribution.datatypes import AccessRule, RoutedProtocol
 from llama_stack.distribution.stack import StackRunConfig
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable

@@ -18,6 +18,7 @@ async def get_routing_table_impl(
     impls_by_provider_id: dict[str, RoutedProtocol],
     _deps,
     dist_registry: DistributionRegistry,
+    policy: list[AccessRule],
 ) -> Any:
     from ..routing_tables.benchmarks import BenchmarksRoutingTable
     from ..routing_tables.datasets import DatasetsRoutingTable

@@ -40,7 +41,7 @@ async def get_routing_table_impl(
     if api.value not in api_to_tables:
         raise ValueError(f"API {api.value} not found in router map")

-    impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
+    impl = api_to_tables[api.value](impls_by_provider_id, dist_registry, policy)
     await impl.initialize()
     return impl
@@ -8,7 +8,7 @@ from typing import Any

 from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
 from llama_stack.distribution.datatypes import (
-    BenchmarkWithACL,
+    BenchmarkWithOwner,
 )
 from llama_stack.log import get_logger

@@ -47,7 +47,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
             )
         if provider_benchmark_id is None:
             provider_benchmark_id = benchmark_id
-        benchmark = BenchmarkWithACL(
+        benchmark = BenchmarkWithOwner(
             identifier=benchmark_id,
             dataset_id=dataset_id,
             scoring_functions=scoring_functions,
@@ -8,14 +8,14 @@ from typing import Any

 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.scoring_functions import ScoringFn
-from llama_stack.distribution.access_control import check_access
+from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
 from llama_stack.distribution.datatypes import (
-    AccessAttributes,
+    AccessRule,
     RoutableObject,
     RoutableObjectWithProvider,
     RoutedProtocol,
 )
-from llama_stack.distribution.request_headers import get_auth_attributes
+from llama_stack.distribution.request_headers import get_authenticated_user
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, RoutingTable

@@ -73,9 +73,11 @@ class CommonRoutingTableImpl(RoutingTable):
         self,
         impls_by_provider_id: dict[str, RoutedProtocol],
         dist_registry: DistributionRegistry,
+        policy: list[AccessRule],
     ) -> None:
         self.impls_by_provider_id = impls_by_provider_id
         self.dist_registry = dist_registry
+        self.policy = policy

     async def initialize(self) -> None:
         async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:

@@ -166,13 +168,15 @@ class CommonRoutingTableImpl(RoutingTable):
             return None

         # Check if user has permission to access this object
-        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
-            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
+        if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
+            logger.debug(f"Access denied to {type} '{identifier}'")
             return None

         return obj

     async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
+        if not is_action_allowed(self.policy, "delete", obj, get_authenticated_user()):
+            raise AccessDeniedError()
         await self.dist_registry.delete(obj.type, obj.identifier)
         await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

@@ -187,11 +191,12 @@ class CommonRoutingTableImpl(RoutingTable):
         p = self.impls_by_provider_id[obj.provider_id]

         # If object supports access control but no attributes set, use creator's attributes
-        if not obj.access_attributes:
-            creator_attributes = get_auth_attributes()
-            if creator_attributes:
-                obj.access_attributes = AccessAttributes(**creator_attributes)
-                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
+        creator = get_authenticated_user()
+        if not is_action_allowed(self.policy, "create", obj, creator):
+            raise AccessDeniedError()
+        if creator:
+            obj.owner = creator
+            logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")

         registered_obj = await register_object_with_provider(obj, p)
         # TODO: This needs to be fixed for all APIs once they return the registered object

@@ -210,9 +215,7 @@ class CommonRoutingTableImpl(RoutingTable):
         # Apply attribute-based access control filtering
         if filtered_objs:
             filtered_objs = [
-                obj
-                for obj in filtered_objs
-                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
+                obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
             ]

         return filtered_objs
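The is_action_allowed implementation itself is not part of this excerpt; as an illustrative sketch of how the policy and the new owner field combine, a configuration like the following would let any user read or register objects while only the recorded owner may delete them:

policy = [
    AccessRule(permit=Scope(actions=[Action.DELETE]), when="user is owner"),
    AccessRule(permit=Scope(actions=[Action.READ, Action.CREATE])),
]
# is_action_allowed(policy, "delete", obj, user) would then only succeed when user is
# the User recorded in obj.owner; "read" and "create" match the second rule for everyone.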
@@ -19,7 +19,7 @@ from llama_stack.apis.datasets import (
 )
 from llama_stack.apis.resource import ResourceType
 from llama_stack.distribution.datatypes import (
-    DatasetWithACL,
+    DatasetWithOwner,
 )
 from llama_stack.log import get_logger

@@ -74,7 +74,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
         if metadata is None:
             metadata = {}

-        dataset = DatasetWithACL(
+        dataset = DatasetWithOwner(
             identifier=dataset_id,
             provider_resource_id=provider_dataset_id,
             provider_id=provider_id,
@@ -9,7 +9,7 @@ from typing import Any

 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
 from llama_stack.distribution.datatypes import (
-    ModelWithACL,
+    ModelWithOwner,
 )
 from llama_stack.log import get_logger

@@ -65,7 +65,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
             raise ValueError("Embedding model must have an embedding dimension in its metadata")
-        model = ModelWithACL(
+        model = ModelWithOwner(
             identifier=model_id,
             provider_resource_id=provider_model_id,
             provider_id=provider_id,
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.scoring_functions import (
|
||||||
ScoringFunctions,
|
ScoringFunctions,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
ScoringFnWithACL,
|
ScoringFnWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
|
@ -50,7 +50,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||||
)
|
)
|
||||||
scoring_fn = ScoringFnWithACL(
|
scoring_fn = ScoringFnWithOwner(
|
||||||
identifier=scoring_fn_id,
|
identifier=scoring_fn_id,
|
||||||
description=description,
|
description=description,
|
||||||
return_type=return_type,
|
return_type=return_type,
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from typing import Any
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
ShieldWithACL,
|
ShieldWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
|
@ -47,7 +47,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
)
|
)
|
||||||
if params is None:
|
if params is None:
|
||||||
params = {}
|
params = {}
|
||||||
shield = ShieldWithACL(
|
shield = ShieldWithOwner(
|
||||||
identifier=shield_id,
|
identifier=shield_id,
|
||||||
provider_resource_id=provider_shield_id,
|
provider_resource_id=provider_shield_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||||
from llama_stack.distribution.datatypes import ToolGroupWithACL
|
from llama_stack.distribution.datatypes import ToolGroupWithOwner
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
|
@ -106,7 +106,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
mcp_endpoint: URL | None = None,
|
mcp_endpoint: URL | None = None,
|
||||||
args: dict[str, Any] | None = None,
|
args: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
toolgroup = ToolGroupWithACL(
|
toolgroup = ToolGroupWithOwner(
|
||||||
identifier=toolgroup_id,
|
identifier=toolgroup_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_resource_id=toolgroup_id,
|
provider_resource_id=toolgroup_id,
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
VectorDBWithACL,
|
VectorDBWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
|
@ -63,7 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
"embedding_model": embedding_model,
|
"embedding_model": embedding_model,
|
||||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||||
}
|
}
|
||||||
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
|
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||||
await self.register_object(vector_db)
|
await self.register_object(vector_db)
|
||||||
return vector_db
|
return vector_db
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -105,24 +105,16 @@ class AuthenticationMiddleware:
                 logger.exception("Error during authentication")
                 return await self._send_auth_error(send, "Authentication service error")

-            # Store attributes in request scope for access control
-            if validation_result.access_attributes:
-                user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
-            else:
-                logger.warning("No access attributes, setting namespace to token by default")
-                user_attributes = {
-                    "roles": [token],
-                }
-
             # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
             # can identify the requester and enforce per-client rate limits.
             scope["authenticated_client_id"] = token

             # Store attributes in request scope
-            scope["user_attributes"] = user_attributes
             scope["principal"] = validation_result.principal
+            if validation_result.attributes:
+                scope["user_attributes"] = validation_result.attributes
             logger.debug(
-                f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
+                f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
             )

             return await self.app(scope, receive, send)
@@ -16,43 +16,18 @@ from jose import jwt
 from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self

-from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
+from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
 from llama_stack.log import get_logger

 logger = get_logger(name=__name__, category="auth")


-class TokenValidationResult(BaseModel):
-    principal: str | None = Field(
-        default=None,
-        description="The principal (username or persistent identifier) of the authenticated user",
-    )
-    access_attributes: AccessAttributes | None = Field(
-        default=None,
-        description="""
-        Structured user attributes for attribute-based access control.
-
-        These attributes determine which resources the user can access.
-        The model provides standard categories like "roles", "teams", "projects", and "namespaces".
-        Each attribute category contains a list of values that the user has for that category.
-        During access control checks, these values are compared against resource requirements.
-
-        Example with standard categories:
-        ```json
-        {
-            "roles": ["admin", "data-scientist"],
-            "teams": ["ml-team"],
-            "projects": ["llama-3"],
-            "namespaces": ["research"]
-        }
-        ```
-        """,
-    )
-
-
-class AuthResponse(TokenValidationResult):
+class AuthResponse(BaseModel):
     """The format of the authentication response from the auth endpoint."""

+    principal: str
+    # further attributes that may be used for access control decisions
+    attributes: dict[str, list[str]] | None = None
     message: str | None = Field(
         default=None, description="Optional message providing additional context about the authentication result."
     )

@@ -78,7 +53,7 @@ class AuthProvider(ABC):
     """Abstract base class for authentication providers."""

     @abstractmethod
-    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+    async def validate_token(self, token: str, scope: dict | None = None) -> User:
         """Validate a token and return access attributes."""
         pass

@@ -88,10 +63,10 @@ class AuthProvider(ABC):
         pass


-def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
-    attributes = AccessAttributes()
+def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
+    attributes: dict[str, list[str]] = {}
     for claim_key, attribute_key in mapping.items():
-        if claim_key not in claims or not hasattr(attributes, attribute_key):
+        if claim_key not in claims:
             continue
         claim = claims[claim_key]
         if isinstance(claim, list):

@@ -99,11 +74,10 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
         else:
             values = claim.split()

-        current = getattr(attributes, attribute_key)
-        if current:
-            current.extend(values)
+        if attribute_key in attributes:
+            attributes[attribute_key].extend(values)
         else:
-            setattr(attributes, attribute_key, values)
+            attributes[attribute_key] = values
     return attributes
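A small illustration of the reworked claims mapping (the claim names, mapping keys and values below are made up); any attribute name is now accepted rather than only the former fixed categories:

claims = {"sub": "user-1", "groups": ["ml-team", "infra-team"], "scope": "read write"}
mapping = {"groups": "teams", "scope": "permissions"}

print(get_attributes_from_claims(claims, mapping))
# -> {"teams": ["ml-team", "infra-team"], "permissions": ["read", "write"]}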
@@ -145,8 +119,6 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
         for key, value in v.items():
             if not value:
                 raise ValueError(f"claims_mapping value cannot be empty: {key}")
-            if value not in AccessAttributes.model_fields:
-                raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
         return v

     @model_validator(mode="after")

@@ -171,14 +143,14 @@ class OAuth2TokenAuthProvider(AuthProvider):
         self._jwks: dict[str, str] = {}
         self._jwks_lock = Lock()

-    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+    async def validate_token(self, token: str, scope: dict | None = None) -> User:
         if self.config.jwks:
             return await self.validate_jwt_token(token, scope)
         if self.config.introspection:
             return await self.introspect_token(token, scope)
         raise ValueError("One of jwks or introspection must be configured")

-    async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+    async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
         """Validate a token using the JWT token."""
         await self._refresh_jwks()

@@ -203,12 +175,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
         # We should incorporate these into the access attributes.
         principal = claims["sub"]
         access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
-        return TokenValidationResult(
+        return User(
             principal=principal,
-            access_attributes=access_attributes,
+            attributes=access_attributes,
         )

-    async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+    async def introspect_token(self, token: str, scope: dict | None = None) -> User:
         """Validate a token using token introspection as defined by RFC 7662."""
         form = {
             "token": token,

@@ -242,9 +214,9 @@ class OAuth2TokenAuthProvider(AuthProvider):
                     raise ValueError("Token not active")
                 principal = fields["sub"] or fields["username"]
                 access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
-                return TokenValidationResult(
+                return User(
                     principal=principal,
-                    access_attributes=access_attributes,
+                    attributes=access_attributes,
                 )
         except httpx.TimeoutException:
             logger.exception("Token introspection request timed out")

@@ -299,7 +271,7 @@ class CustomAuthProvider(AuthProvider):
         self.config = config
         self._client = None

-    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+    async def validate_token(self, token: str, scope: dict | None = None) -> User:
         """Validate a token using the custom authentication endpoint."""
         if scope is None:
             scope = {}

@@ -341,7 +313,7 @@ class CustomAuthProvider(AuthProvider):
                 try:
                     response_data = response.json()
                     auth_response = AuthResponse(**response_data)
-                    return auth_response
+                    return User(auth_response.principal, auth_response.attributes)
                 except Exception as e:
                     logger.exception("Error parsing authentication response")
                     raise ValueError("Invalid authentication response format") from e
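For illustration (field values invented), the JSON a custom auth endpoint is now expected to return maps directly onto the reworked AuthResponse model and is converted to a User as shown above:

response_data = {
    "principal": "user-1",
    "attributes": {"roles": ["admin"], "teams": ["ml-team"]},
    "message": "authenticated via example-sso",  # optional
}

auth_response = AuthResponse(**response_data)
user = User(auth_response.principal, auth_response.attributes)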
@@ -18,7 +18,7 @@ from collections.abc import Callable
 from contextlib import asynccontextmanager
 from importlib.metadata import version as parse_version
 from pathlib import Path
-from typing import Annotated, Any
+from typing import Annotated, Any, get_origin

 import rich.pretty
 import yaml

@@ -26,17 +26,13 @@ from aiohttp import hdrs
 from fastapi import Body, FastAPI, HTTPException, Request
 from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
-from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError

 from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
-from llama_stack.distribution.request_headers import (
-    PROVIDER_DATA_VAR,
-    request_provider_data_context,
-)
+from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.server.routes import (
     find_matching_route,

@@ -217,11 +213,13 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
     async def route_handler(request: Request, **kwargs):
         # Get auth attributes from the request scope
         user_attributes = request.scope.get("user_attributes", {})
+        principal = request.scope.get("principal", "")
+        user = User(principal, user_attributes)

         await log_request_pre_validation(request)

         # Use context manager with both provider data and auth attributes
-        with request_provider_data_context(request.headers, user_attributes):
+        with request_provider_data_context(request.headers, user):
             is_streaming = is_streaming_request(func.__name__, request, **kwargs)

             try:

@@ -244,15 +242,23 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:

     path_params = extract_path_params(route)
     if method == "post":
-        # Annotate parameters that are in the path with Path(...) and others with Body(...)
-        new_params = [new_params[0]] + [
+        # Annotate parameters that are in the path with Path(...) and others with Body(...),
+        # but preserve existing File() and Form() annotations for multipart form data
+        new_params = (
+            [new_params[0]]
+            + [
                 (
                     param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
                     if param.name in path_params
+                    else (
+                        param  # Keep original annotation if it's already an Annotated type
+                        if get_origin(param.annotation) is Annotated
                         else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
+                    )
                 )
                 for param in new_params[1:]
             ]
+        )

     route_handler.__signature__ = sig.replace(parameters=new_params)

@@ -472,17 +478,6 @@ def main(args: argparse.Namespace | None = None):
             window_seconds=window_seconds,
         )

-    # --- CORS middleware for local development ---
-    # TODO: move to reverse proxy
-    ui_port = os.environ.get("LLAMA_STACK_UI_PORT", 8322)
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=[f"http://localhost:{ui_port}"],
-        allow_credentials=True,
-        allow_methods=["*"],
-        allow_headers=["*"],
-    )
-
     try:
         impls = asyncio.run(construct_stack(config))
     except InvalidProviderError as e:
@@ -223,7 +223,10 @@ async def construct_stack(
     run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
 ) -> dict[Api, Any]:
     dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
+    policy = run_config.server.auth.access_policy if run_config.server.auth else []
+    impls = await resolve_impls(
+        run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
+    )

     # Add internal implementations after all other providers are resolved
     add_internal_implementations(impls, run_config)
@@ -7,10 +7,6 @@
 # the root directory of this source tree.


-CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
-CONTAINER_OPTS=${CONTAINER_OPTS:-}
-LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
-LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
 PYPI_VERSION=${PYPI_VERSION:-}
 VIRTUAL_ENV=${VIRTUAL_ENV:-}

@@ -132,63 +128,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
     $env_vars \
     $other_args
 elif [[ "$env_type" == "container" ]]; then
-  set -x
-
-  # Check if container command is available
-  if ! is_command_available $CONTAINER_BINARY; then
-    printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
-    exit 1
-  fi
-
-  if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
-    # Disable SELinux labels
-    CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
-  fi
-
-  mounts=""
-  if [ -n "$LLAMA_STACK_DIR" ]; then
-    mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
-  fi
-  if [ -n "$LLAMA_CHECKPOINT_DIR" ]; then
-    mounts="$mounts -v $LLAMA_CHECKPOINT_DIR:/root/.llama"
-    CONTAINER_OPTS="$CONTAINER_OPTS --gpus=all"
-  fi
-
-  if [ -n "$PYPI_VERSION" ]; then
-    version_tag="$PYPI_VERSION"
-  elif [ -n "$LLAMA_STACK_DIR" ]; then
-    version_tag="dev"
-  elif [ -n "$TEST_PYPI_VERSION" ]; then
-    version_tag="test-$TEST_PYPI_VERSION"
-  else
-    if ! is_command_available jq; then
-      echo -e "${RED}Error: jq not found" >&2
-      exit 1
-    fi
-    URL="https://pypi.org/pypi/llama-stack/json"
-    version_tag=$(curl -s $URL | jq -r '.info.version')
-  fi
-
-  # Build the command with optional yaml config
-  cmd="$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-  -p $port:$port \
-  $env_vars \
-  $mounts \
-  --env LLAMA_STACK_PORT=$port \
-  --entrypoint python \
-  $container_image:$version_tag \
-  -m llama_stack.distribution.server.server"
-
-  # Add yaml config if provided, otherwise use default
-  if [ -n "$yaml_config" ]; then
-    cmd="$cmd -v $yaml_config:/app/run.yaml --config /app/run.yaml"
-  else
-    cmd="$cmd --config /app/run.yaml"
-  fi
-
-  # Add any other args
-  cmd="$cmd $other_args"
-
-  # Execute the command
-  eval $cmd
+  echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
+  echo -e "Please refer to the documentation for more information: https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html#llama-stack-build"
+  exit 1
 fi
@@ -23,11 +23,8 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType

 def formulate_run_args(image_type, image_name, config, template_name) -> list:
     env_name = ""
-    if image_type == LlamaStackImageType.CONTAINER.value:
-        env_name = (
-            f"distribution-{template_name}" if template_name else (config.container_image if config else image_name)
-        )
-    elif image_type == LlamaStackImageType.CONDA.value:
+    if image_type == LlamaStackImageType.CONDA.value:
         current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
         env_name = image_name or current_conda_env
         if not env_name:
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import os
 from collections.abc import Collection, Iterator, Sequence, Set
 from logging import getLogger
 from pathlib import Path

@@ -14,7 +13,8 @@ from typing import (
 )

 import tiktoken
-from tiktoken.load import load_tiktoken_bpe
+
+from llama_stack.models.llama.tokenizer_utils import load_bpe_file

 logger = getLogger(__name__)

@@ -48,19 +48,20 @@ class Tokenizer:
         global _INSTANCE

         if _INSTANCE is None:
-            _INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
+            _INSTANCE = Tokenizer(Path(__file__).parent / "tokenizer.model")
         return _INSTANCE

-    def __init__(self, model_path: str):
+    def __init__(self, model_path: Path):
         """
         Initializes the Tokenizer with a Tiktoken model.

         Args:
             model_path (str): The path to the Tiktoken model file.
         """
-        assert os.path.isfile(model_path), model_path
+        if not model_path.exists():
+            raise FileNotFoundError(f"Tokenizer model file not found: {model_path}")

-        mergeable_ranks = load_tiktoken_bpe(model_path)
+        mergeable_ranks = load_bpe_file(model_path)
         num_base_tokens = len(mergeable_ranks)
         special_tokens = [
             "<|begin_of_text|>",

@@ -83,7 +84,7 @@ class Tokenizer:

         self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
         self.model = tiktoken.Encoding(
-            name=Path(model_path).name,
+            name=model_path.name,
             pat_str=self.pat_str,
             mergeable_ranks=mergeable_ranks,
             special_tokens=self.special_tokens,
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import os
 from collections.abc import Collection, Iterator, Sequence, Set
 from logging import getLogger
 from pathlib import Path

@@ -14,7 +13,8 @@ from typing import (
 )

 import tiktoken
-from tiktoken.load import load_tiktoken_bpe
+
+from llama_stack.models.llama.tokenizer_utils import load_bpe_file

 logger = getLogger(__name__)

@@ -118,19 +118,20 @@ class Tokenizer:
         global _INSTANCE

         if _INSTANCE is None:
-            _INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
+            _INSTANCE = Tokenizer(Path(__file__).parent / "tokenizer.model")
         return _INSTANCE

-    def __init__(self, model_path: str):
+    def __init__(self, model_path: Path):
         """
         Initializes the Tokenizer with a Tiktoken model.

         Args:
-            model_path (str): The path to the Tiktoken model file.
+            model_path (Path): The path to the Tiktoken model file.
         """
-        assert os.path.isfile(model_path), model_path
+        if not model_path.exists():
+            raise FileNotFoundError(f"Tokenizer model file not found: {model_path}")

-        mergeable_ranks = load_tiktoken_bpe(model_path)
+        mergeable_ranks = load_bpe_file(model_path)
         num_base_tokens = len(mergeable_ranks)

         special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS

@@ -144,7 +145,7 @@ class Tokenizer:

         self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
         self.model = tiktoken.Encoding(
-            name=Path(model_path).name,
+            name=model_path.name,
             pat_str=self.O200K_PATTERN,
             mergeable_ranks=mergeable_ranks,
             special_tokens=self.special_tokens,
llama_stack/models/llama/tokenizer_utils.py (new file, +40)
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+from pathlib import Path
+
+from llama_stack.log import get_logger
+
+logger = get_logger(__name__, "tokenizer_utils")
+
+
+def load_bpe_file(model_path: Path) -> dict[bytes, int]:
+    """
+    Load BPE file directly and return mergeable ranks.
+
+    Args:
+        model_path (Path): Path to the BPE model file.
+
+    Returns:
+        dict[bytes, int]: Dictionary mapping byte sequences to their ranks.
+    """
+    mergeable_ranks = {}
+
+    with open(model_path, encoding="utf-8") as f:
+        content = f.read()
+
+    for line in content.splitlines():
+        if not line.strip():  # Skip empty lines
+            continue
+        try:
+            token, rank = line.split()
+            mergeable_ranks[base64.b64decode(token)] = int(rank)
+        except Exception as e:
+            logger.warning(f"Failed to parse line '{line}': {e}")
+            continue
+
+    return mergeable_ranks
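Note: `load_bpe_file` replaces `tiktoken.load.load_tiktoken_bpe` and appears to assume the same on-disk layout: one base64-encoded token plus its integer rank per line. A minimal sketch of that assumption, using a throwaway file and made-up tokens rather than anything from the repository:

```
# Illustrative only: write a tiny BPE file in the format load_bpe_file expects,
# then load it back. The path and token values are invented for the example.
import base64
import tempfile
from pathlib import Path

from llama_stack.models.llama.tokenizer_utils import load_bpe_file

with tempfile.TemporaryDirectory() as tmp:
    bpe_path = Path(tmp) / "toy.model"
    lines = [
        f"{base64.b64encode(b'hello').decode()} 0",
        f"{base64.b64encode(b'world').decode()} 1",
    ]
    bpe_path.write_text("\n".join(lines), encoding="utf-8")

    ranks = load_bpe_file(bpe_path)
    assert ranks == {b"hello": 0, b"world": 1}
```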
@@ -6,12 +6,12 @@
 
 from typing import Any
 
-from llama_stack.distribution.datatypes import Api
+from llama_stack.distribution.datatypes import AccessRule, Api
 
 from .config import MetaReferenceAgentsImplConfig
 
 
-async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
+async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
     from .agents import MetaReferenceAgentsImpl
 
     impl = MetaReferenceAgentsImpl(
@@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
         deps[Api.safety],
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
+        policy,
     )
     await impl.initialize()
     return impl
@@ -60,6 +60,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.distribution.datatypes import AccessRule
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
@@ -96,13 +97,14 @@ class ChatAgent(ShieldRunnerMixin):
         vector_io_api: VectorIO,
         persistence_store: KVStore,
         created_at: str,
+        policy: list[AccessRule],
     ):
         self.agent_id = agent_id
         self.agent_config = agent_config
         self.inference_api = inference_api
         self.safety_api = safety_api
         self.vector_io_api = vector_io_api
-        self.storage = AgentPersistence(agent_id, persistence_store)
+        self.storage = AgentPersistence(agent_id, persistence_store, policy)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
         self.created_at = created_at
@@ -29,6 +29,7 @@ from llama_stack.apis.agents import (
     Session,
     Turn,
 )
+from llama_stack.apis.agents.openai_responses import OpenAIResponseText
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.inference import (
     Inference,
@@ -40,6 +41,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.distribution.datatypes import AccessRule
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@@ -61,6 +63,7 @@ class MetaReferenceAgentsImpl(Agents):
         safety_api: Safety,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
+        policy: list[AccessRule],
     ):
         self.config = config
         self.inference_api = inference_api
@@ -71,6 +74,7 @@ class MetaReferenceAgentsImpl(Agents):
 
         self.in_memory_store = InmemoryKVStoreImpl()
         self.openai_responses_impl: OpenAIResponsesImpl | None = None
+        self.policy = policy
 
     async def initialize(self) -> None:
         self.persistence_store = await kvstore_impl(self.config.persistence_store)
@@ -129,6 +133,7 @@ class MetaReferenceAgentsImpl(Agents):
                 self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
             ),
             created_at=agent_info.created_at,
+            policy=self.policy,
         )
 
     async def create_agent_session(
@@ -324,10 +329,12 @@ class MetaReferenceAgentsImpl(Agents):
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
+        text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
+        max_infer_iters: int | None = 10,
     ) -> OpenAIResponseObject:
         return await self.openai_responses_impl.create_openai_response(
-            input, model, instructions, previous_response_id, store, stream, temperature, tools
+            input, model, instructions, previous_response_id, store, stream, temperature, text, tools, max_infer_iters
         )
 
     async def list_openai_responses(
@@ -37,6 +37,8 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutputMessageFunctionToolCall,
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
+    OpenAIResponseText,
+    OpenAIResponseTextFormat,
 )
 from llama_stack.apis.inference.inference import (
     Inference,
@@ -50,7 +52,12 @@ from llama_stack.apis.inference.inference import (
     OpenAIChoice,
     OpenAIDeveloperMessageParam,
     OpenAIImageURL,
+    OpenAIJSONSchema,
     OpenAIMessageParam,
+    OpenAIResponseFormatJSONObject,
+    OpenAIResponseFormatJSONSchema,
+    OpenAIResponseFormatParam,
+    OpenAIResponseFormatText,
     OpenAISystemMessageParam,
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
@@ -158,6 +165,21 @@ async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> Open
     )
 
 
+async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam:
+    """
+    Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
+    """
+    if not text.format or text.format["type"] == "text":
+        return OpenAIResponseFormatText(type="text")
+    if text.format["type"] == "json_object":
+        return OpenAIResponseFormatJSONObject()
+    if text.format["type"] == "json_schema":
+        return OpenAIResponseFormatJSONSchema(
+            json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
+        )
+    raise ValueError(f"Unsupported text format: {text.format}")
+
+
 async def _get_message_type_by_role(role: str):
     role_to_type = {
         "user": OpenAIUserMessageParam,
@@ -180,6 +202,7 @@ class ChatCompletionContext(BaseModel):
     mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
     stream: bool
     temperature: float | None
+    response_format: OpenAIResponseFormatParam
 
 
 class OpenAIResponsesImpl:
@@ -258,6 +281,18 @@ class OpenAIResponsesImpl:
         """
         return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
 
+    def _is_function_tool_call(
+        self,
+        tool_call: OpenAIChatCompletionToolCall,
+        tools: list[OpenAIResponseInputTool],
+    ) -> bool:
+        if not tool_call.function:
+            return False
+        for t in tools:
+            if t.type == "function" and t.name == tool_call.function.name:
+                return True
+        return False
+
     async def _process_response_choices(
         self,
         chat_response: OpenAIChatCompletion,
@@ -270,7 +305,7 @@ class OpenAIResponsesImpl:
         for choice in chat_response.choices:
             if choice.message.tool_calls and tools:
                 # Assume if the first tool is a function, all tools are functions
-                if tools[0].type == "function":
+                if self._is_function_tool_call(choice.message.tool_calls[0], tools):
                     for tool_call in choice.message.tool_calls:
                         output_messages.append(
                             OpenAIResponseOutputMessageFunctionToolCall(
@@ -331,9 +366,12 @@ class OpenAIResponsesImpl:
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
+        text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
+        max_infer_iters: int | None = 10,
     ):
         stream = False if stream is None else stream
+        text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
 
         output_messages: list[OpenAIResponseOutput] = []
 
@@ -342,6 +380,9 @@ class OpenAIResponsesImpl:
         messages = await _convert_response_input_to_chat_messages(input)
         await self._prepend_instructions(messages, instructions)
 
+        # Structured outputs
+        response_format = await _convert_response_text_to_chat_response_format(text)
+
         # Tool setup
         chat_tools, mcp_tool_to_server, mcp_list_message = (
             await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
@@ -356,65 +397,111 @@ class OpenAIResponsesImpl:
             mcp_tool_to_server=mcp_tool_to_server,
             stream=stream,
             temperature=temperature,
+            response_format=response_format,
         )
 
-        inference_result = await self.inference_api.openai_chat_completion(
-            model=model,
-            messages=messages,
-            tools=chat_tools,
-            stream=stream,
-            temperature=temperature,
-        )
-
+        # Fork to streaming vs non-streaming - let each handle ALL inference rounds
         if stream:
             return self._create_streaming_response(
-                inference_result=inference_result,
                 ctx=ctx,
                 output_messages=output_messages,
                 input=input,
                 model=model,
                 store=store,
+                text=text,
                 tools=tools,
+                max_infer_iters=max_infer_iters,
             )
         else:
             return await self._create_non_streaming_response(
-                inference_result=inference_result,
                 ctx=ctx,
                 output_messages=output_messages,
                 input=input,
                 model=model,
                 store=store,
+                text=text,
                 tools=tools,
+                max_infer_iters=max_infer_iters,
            )
 
     async def _create_non_streaming_response(
         self,
-        inference_result: Any,
         ctx: ChatCompletionContext,
         output_messages: list[OpenAIResponseOutput],
         input: str | list[OpenAIResponseInput],
         model: str,
         store: bool | None,
+        text: OpenAIResponseText,
         tools: list[OpenAIResponseInputTool] | None,
+        max_infer_iters: int,
     ) -> OpenAIResponseObject:
-        chat_response = OpenAIChatCompletion(**inference_result.model_dump())
+        n_iter = 0
+        messages = ctx.messages.copy()
 
-        # Process response choices (tool execution and message creation)
-        output_messages.extend(
-            await self._process_response_choices(
-                chat_response=chat_response,
+        while True:
+            # Do inference (including the first one)
+            inference_result = await self.inference_api.openai_chat_completion(
+                model=ctx.model,
+                messages=messages,
+                tools=ctx.tools,
+                stream=False,
+                temperature=ctx.temperature,
+                response_format=ctx.response_format,
+            )
+            completion = OpenAIChatCompletion(**inference_result.model_dump())
+
+            # Separate function vs non-function tool calls
+            function_tool_calls = []
+            non_function_tool_calls = []
+
+            for choice in completion.choices:
+                if choice.message.tool_calls and tools:
+                    for tool_call in choice.message.tool_calls:
+                        if self._is_function_tool_call(tool_call, tools):
+                            function_tool_calls.append(tool_call)
+                        else:
+                            non_function_tool_calls.append(tool_call)
+
+            # Process response choices based on tool call types
+            if function_tool_calls:
+                # For function tool calls, use existing logic and return immediately
+                current_output_messages = await self._process_response_choices(
+                    chat_response=completion,
                     ctx=ctx,
                     tools=tools,
                 )
-        )
+                output_messages.extend(current_output_messages)
+                break
+            elif non_function_tool_calls:
+                # For non-function tool calls, execute them and continue loop
+                for choice in completion.choices:
+                    tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
+                    output_messages.extend(tool_outputs)
+
+                    # Add assistant message and tool responses to messages for next iteration
+                    messages.append(choice.message)
+                    messages.extend(tool_response_messages)
+
+                n_iter += 1
+                if n_iter >= max_infer_iters:
+                    break
+
+                # Continue with next iteration of the loop
+                continue
+            else:
+                # No tool calls - convert response to message and we're done
+                for choice in completion.choices:
+                    output_messages.append(await _convert_chat_choice_to_response_message(choice))
+                break
 
         response = OpenAIResponseObject(
-            created_at=chat_response.created,
+            created_at=completion.created,
             id=f"resp-{uuid.uuid4()}",
             model=model,
             object="response",
             status="completed",
             output=output_messages,
+            text=text,
         )
         logger.debug(f"OpenAI Responses response: {response}")
@@ -429,13 +516,14 @@ class OpenAIResponsesImpl:
 
     async def _create_streaming_response(
         self,
-        inference_result: Any,
         ctx: ChatCompletionContext,
         output_messages: list[OpenAIResponseOutput],
         input: str | list[OpenAIResponseInput],
         model: str,
         store: bool | None,
+        text: OpenAIResponseText,
         tools: list[OpenAIResponseInputTool] | None,
+        max_infer_iters: int | None,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Create initial response and emit response.created immediately
         response_id = f"resp-{uuid.uuid4()}"
@@ -448,13 +536,27 @@ class OpenAIResponsesImpl:
             object="response",
             status="in_progress",
             output=output_messages.copy(),
+            text=text,
         )
 
         # Emit response.created immediately
         yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
 
-        # For streaming, inference_result is an async iterator of chunks
-        # Stream chunks and emit delta events as they arrive
+        # Implement tool execution loop for streaming - handle ALL inference rounds including the first
+        n_iter = 0
+        messages = ctx.messages.copy()
+
+        while True:
+            current_inference_result = await self.inference_api.openai_chat_completion(
+                model=ctx.model,
+                messages=messages,
+                tools=ctx.tools,
+                stream=True,
+                temperature=ctx.temperature,
+                response_format=ctx.response_format,
+            )
+
+            # Process streaming chunks and build complete response
            chat_response_id = ""
            chat_response_content = []
            chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
@@ -466,7 +568,7 @@ class OpenAIResponsesImpl:
            # Create a placeholder message item for delta events
            message_item_id = f"msg_{uuid.uuid4()}"
 
-        async for chunk in inference_result:
+            async for chunk in current_inference_result:
                chat_response_id = chunk.id
                chunk_created = chunk.created
                chunk_model = chunk.model
@@ -487,12 +589,12 @@ class OpenAIResponsesImpl:
                if chunk_choice.finish_reason:
                    chunk_finish_reason = chunk_choice.finish_reason
 
-                # Aggregate tool call arguments across chunks, using their index as the aggregation key
+                # Aggregate tool call arguments across chunks
                if chunk_choice.delta.tool_calls:
                    for tool_call in chunk_choice.delta.tool_calls:
                        response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
                        if response_tool_call:
                            # Don't attempt to concatenate arguments if we don't have any new arguments
                            if tool_call.function.arguments:
                                # Guard against an initial None argument before we concatenate
                                response_tool_call.function.arguments = (
@@ -513,7 +615,7 @@ class OpenAIResponsesImpl:
                content="".join(chat_response_content),
                tool_calls=tool_calls,
            )
-        chat_response_obj = OpenAIChatCompletion(
+            current_response = OpenAIChatCompletion(
                id=chat_response_id,
                choices=[
                    OpenAIChoice(
@@ -526,14 +628,49 @@ class OpenAIResponsesImpl:
                model=chunk_model,
            )
 
-        # Process response choices (tool execution and message creation)
-        output_messages.extend(
-            await self._process_response_choices(
-                chat_response=chat_response_obj,
+            # Separate function vs non-function tool calls
+            function_tool_calls = []
+            non_function_tool_calls = []
+
+            for choice in current_response.choices:
+                if choice.message.tool_calls and tools:
+                    for tool_call in choice.message.tool_calls:
+                        if self._is_function_tool_call(tool_call, tools):
+                            function_tool_calls.append(tool_call)
+                        else:
+                            non_function_tool_calls.append(tool_call)
+
+            # Process response choices based on tool call types
+            if function_tool_calls:
+                # For function tool calls, use existing logic and break
+                current_output_messages = await self._process_response_choices(
+                    chat_response=current_response,
                    ctx=ctx,
                    tools=tools,
                )
-        )
+                output_messages.extend(current_output_messages)
+                break
+            elif non_function_tool_calls:
+                # For non-function tool calls, execute them and continue loop
+                for choice in current_response.choices:
+                    tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
+                    output_messages.extend(tool_outputs)
+
+                    # Add assistant message and tool responses to messages for next iteration
+                    messages.append(choice.message)
+                    messages.extend(tool_response_messages)
+
+                n_iter += 1
+                if n_iter >= (max_infer_iters or 10):
+                    break
+
+                # Continue with next iteration of the loop
+                continue
+            else:
+                # No tool calls - convert response to message and we're done
+                for choice in current_response.choices:
+                    output_messages.append(await _convert_chat_choice_to_response_message(choice))
+                break
 
         # Create final response
         final_response = OpenAIResponseObject(
@@ -542,6 +679,7 @@ class OpenAIResponsesImpl:
             model=model,
             object="response",
             status="completed",
+            text=text,
             output=output_messages,
         )
 
@@ -646,6 +784,30 @@ class OpenAIResponsesImpl:
             raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
         return chat_tools, mcp_tool_to_server, mcp_list_message
 
+    async def _execute_tool_calls_only(
+        self,
+        choice: OpenAIChoice,
+        ctx: ChatCompletionContext,
+    ) -> tuple[list[OpenAIResponseOutput], list[OpenAIMessageParam]]:
+        """Execute tool calls and return output messages and tool response messages for next inference."""
+        output_messages: list[OpenAIResponseOutput] = []
+        tool_response_messages: list[OpenAIMessageParam] = []
+
+        if not isinstance(choice.message, OpenAIAssistantMessageParam):
+            return output_messages, tool_response_messages
+
+        if not choice.message.tool_calls:
+            return output_messages, tool_response_messages
+
+        for tool_call in choice.message.tool_calls:
+            tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
+            if tool_call_log:
+                output_messages.append(tool_call_log)
+            if further_input:
+                tool_response_messages.append(further_input)
+
+        return output_messages, tool_response_messages
+
     async def _execute_tool_and_return_final_output(
         self,
         choice: OpenAIChoice,
@@ -772,5 +934,8 @@ class OpenAIResponsesImpl:
             else:
                 raise ValueError(f"Unknown result content type: {type(result.content)}")
             input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
+    else:
+        text = str(error_exc)
+        input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
 
     return message, input_message
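Taken together, these hunks thread a structured-output `text` parameter and a `max_infer_iters` bound through the Responses implementation. A rough sketch of how a caller might exercise the new parameters, assuming `OpenAIResponseTextFormat` accepts the `type`/`name`/`schema` fields read by the converter above; the model id, schema, and `impl` instance are placeholders, and the call would run inside an async context:

```
# Hypothetical call against an OpenAIResponsesImpl instance named `impl`.
from llama_stack.apis.agents.openai_responses import OpenAIResponseText, OpenAIResponseTextFormat

response = await impl.create_openai_response(
    input="Extract the city and country from: 'Paris is the capital of France.'",
    model="example-model",  # placeholder model id
    text=OpenAIResponseText(
        format=OpenAIResponseTextFormat(
            type="json_schema",
            name="location",
            schema={
                "type": "object",
                "properties": {"city": {"type": "string"}, "country": {"type": "string"}},
                "required": ["city", "country"],
            },
        )
    ),
    max_infer_iters=4,  # cap on the tool-execution loop introduced above
)
```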
@@ -10,9 +10,10 @@ import uuid
 from datetime import datetime, timezone
 
 from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
-from llama_stack.distribution.access_control import check_access
+from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
-from llama_stack.distribution.datatypes import AccessAttributes
+from llama_stack.distribution.access_control.datatypes import AccessRule
-from llama_stack.distribution.request_headers import get_auth_attributes
+from llama_stack.distribution.datatypes import User
+from llama_stack.distribution.request_headers import get_authenticated_user
 from llama_stack.providers.utils.kvstore import KVStore
 
 log = logging.getLogger(__name__)
@@ -22,7 +23,9 @@ class AgentSessionInfo(Session):
     # TODO: is this used anywhere?
     vector_db_id: str | None = None
     started_at: datetime
-    access_attributes: AccessAttributes | None = None
+    owner: User | None = None
+    identifier: str | None = None
+    type: str = "session"
 
 
 class AgentInfo(AgentConfig):
@@ -30,24 +33,27 @@ class AgentInfo(AgentConfig):
 
 
 class AgentPersistence:
-    def __init__(self, agent_id: str, kvstore: KVStore):
+    def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
         self.agent_id = agent_id
         self.kvstore = kvstore
+        self.policy = policy
 
     async def create_session(self, name: str) -> str:
         session_id = str(uuid.uuid4())
 
         # Get current user's auth attributes for new sessions
-        auth_attributes = get_auth_attributes()
+        user = get_authenticated_user()
-        access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None
 
         session_info = AgentSessionInfo(
             session_id=session_id,
             session_name=name,
             started_at=datetime.now(timezone.utc),
-            access_attributes=access_attributes,
+            owner=user,
             turns=[],
+            identifier=name,  # should this be qualified in any way?
         )
+        if not is_action_allowed(self.policy, "create", session_info, user):
+            raise AccessDeniedError()
 
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
@@ -73,10 +79,10 @@ class AgentPersistence:
     def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
         """Check if current user has access to the session."""
         # Handle backward compatibility for old sessions without access control
-        if not hasattr(session_info, "access_attributes"):
+        if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
             return True
 
-        return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
+        return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
 
     async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
         """Get session info if the user has access to it. For internal use by sub-session methods."""
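The persistence layer now consults the configured access policy rather than raw auth attributes. A compressed, hedged sketch of the guard pattern the diff applies on session creation and reads, using only the `is_action_allowed(policy, action, resource, user)` and `AccessDeniedError` names imported above (the `guard` helper itself is illustrative, not part of the codebase):

```
# Illustrative helper showing the policy check pattern; not actual stack code.
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed


def guard(policy, action, resource, user):
    """Raise unless the policy allows `user` to perform `action` on `resource`."""
    if not is_action_allowed(policy, action, resource, user):
        raise AccessDeniedError()
```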
llama_stack/providers/inline/files/localfs/__init__.py (new file, +20)
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from llama_stack.distribution.datatypes import Api
+
+from .config import LocalfsFilesImplConfig
+from .files import LocalfsFilesImpl
+
+__all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"]
+
+
+async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any]):
+    impl = LocalfsFilesImpl(config)
+    await impl.initialize()
+    return impl
llama_stack/providers/inline/files/localfs/config.py (new file, +31)
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
+
+
+class LocalfsFilesImplConfig(BaseModel):
+    storage_dir: str = Field(
+        description="Directory to store uploaded files",
+    )
+    metadata_store: SqlStoreConfig = Field(
+        description="SQL store configuration for file metadata",
+    )
+    ttl_secs: int = 365 * 24 * 60 * 60  # 1 year
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
+        return {
+            "storage_dir": "${env.FILES_STORAGE_DIR:" + __distro_dir__ + "/files}",
+            "metadata_store": SqliteSqlStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="files_metadata.db",
+            ),
+        }
llama_stack/providers/inline/files/localfs/files.py (new file, +214)
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import time
+import uuid
+from pathlib import Path
+from typing import Annotated
+
+from fastapi import File, Form, Response, UploadFile
+
+from llama_stack.apis.common.responses import Order
+from llama_stack.apis.files import (
+    Files,
+    ListOpenAIFileResponse,
+    OpenAIFileDeleteResponse,
+    OpenAIFileObject,
+    OpenAIFilePurpose,
+)
+from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
+from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
+
+from .config import LocalfsFilesImplConfig
+
+
+class LocalfsFilesImpl(Files):
+    def __init__(self, config: LocalfsFilesImplConfig) -> None:
+        self.config = config
+        self.sql_store: SqlStore | None = None
+
+    async def initialize(self) -> None:
+        """Initialize the files provider by setting up storage directory and metadata database."""
+        # Create storage directory if it doesn't exist
+        storage_path = Path(self.config.storage_dir)
+        storage_path.mkdir(parents=True, exist_ok=True)
+
+        # Initialize SQL store for metadata
+        self.sql_store = sqlstore_impl(self.config.metadata_store)
+        await self.sql_store.create_table(
+            "openai_files",
+            {
+                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
+                "filename": ColumnType.STRING,
+                "purpose": ColumnType.STRING,
+                "bytes": ColumnType.INTEGER,
+                "created_at": ColumnType.INTEGER,
+                "expires_at": ColumnType.INTEGER,
+                "file_path": ColumnType.STRING,  # Path to actual file on disk
+            },
+        )
+
+    def _generate_file_id(self) -> str:
+        """Generate a unique file ID for OpenAI API."""
+        return f"file-{uuid.uuid4().hex}"
+
+    def _get_file_path(self, file_id: str) -> Path:
+        """Get the filesystem path for a file ID."""
+        return Path(self.config.storage_dir) / file_id
+
+    # OpenAI Files API Implementation
+    async def openai_upload_file(
+        self,
+        file: Annotated[UploadFile, File()],
+        purpose: Annotated[OpenAIFilePurpose, Form()],
+    ) -> OpenAIFileObject:
+        """Upload a file that can be used across various endpoints."""
+        if not self.sql_store:
+            raise RuntimeError("Files provider not initialized")
+
+        file_id = self._generate_file_id()
+        file_path = self._get_file_path(file_id)
+
+        content = await file.read()
+        file_size = len(content)
+
+        with open(file_path, "wb") as f:
+            f.write(content)
+
+        created_at = int(time.time())
+        expires_at = created_at + self.config.ttl_secs
+
+        await self.sql_store.insert(
+            "openai_files",
+            {
+                "id": file_id,
+                "filename": file.filename or "uploaded_file",
+                "purpose": purpose.value,
+                "bytes": file_size,
+                "created_at": created_at,
+                "expires_at": expires_at,
+                "file_path": file_path.as_posix(),
+            },
+        )
+
+        return OpenAIFileObject(
+            id=file_id,
+            filename=file.filename or "uploaded_file",
+            purpose=purpose,
+            bytes=file_size,
+            created_at=created_at,
+            expires_at=expires_at,
+        )
+
+    async def openai_list_files(
+        self,
+        after: str | None = None,
+        limit: int | None = 10000,
+        order: Order | None = Order.desc,
+        purpose: OpenAIFilePurpose | None = None,
+    ) -> ListOpenAIFileResponse:
+        """Returns a list of files that belong to the user's organization."""
+        if not self.sql_store:
+            raise RuntimeError("Files provider not initialized")
+
+        # TODO: Implement 'after' pagination properly
+        if after:
+            raise NotImplementedError("After pagination not yet implemented")
+
+        where = None
+        if purpose:
+            where = {"purpose": purpose.value}
+
+        rows = await self.sql_store.fetch_all(
+            "openai_files",
+            where=where,
+            order_by=[("created_at", order.value if order else Order.desc.value)],
+            limit=limit,
+        )
+
+        files = [
+            OpenAIFileObject(
+                id=row["id"],
+                filename=row["filename"],
+                purpose=OpenAIFilePurpose(row["purpose"]),
+                bytes=row["bytes"],
+                created_at=row["created_at"],
+                expires_at=row["expires_at"],
+            )
+            for row in rows
+        ]
+
+        return ListOpenAIFileResponse(
+            data=files,
+            has_more=False,  # TODO: Implement proper pagination
+            first_id=files[0].id if files else "",
+            last_id=files[-1].id if files else "",
+        )
+
+    async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
+        """Returns information about a specific file."""
+        if not self.sql_store:
+            raise RuntimeError("Files provider not initialized")
+
+        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
+        if not row:
+            raise ValueError(f"File with id {file_id} not found")
+
+        return OpenAIFileObject(
+            id=row["id"],
+            filename=row["filename"],
+            purpose=OpenAIFilePurpose(row["purpose"]),
+            bytes=row["bytes"],
+            created_at=row["created_at"],
+            expires_at=row["expires_at"],
+        )
+
+    async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
+        """Delete a file."""
+        if not self.sql_store:
+            raise RuntimeError("Files provider not initialized")
+
+        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
+        if not row:
+            raise ValueError(f"File with id {file_id} not found")
+
+        # Delete physical file
+        file_path = Path(row["file_path"])
+        if file_path.exists():
+            file_path.unlink()
+
+        # Delete metadata from database
+        await self.sql_store.delete("openai_files", where={"id": file_id})
+
+        return OpenAIFileDeleteResponse(
+            id=file_id,
+            deleted=True,
+        )
+
+    async def openai_retrieve_file_content(self, file_id: str) -> Response:
+        """Returns the contents of the specified file."""
+        if not self.sql_store:
+            raise RuntimeError("Files provider not initialized")
+
+        # Get file metadata
+        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
+        if not row:
+            raise ValueError(f"File with id {file_id} not found")
+
+        # Read file content
+        file_path = Path(row["file_path"])
+        if not file_path.exists():
+            raise ValueError(f"File content not found on disk: {file_path}")
+
+        with open(file_path, "rb") as f:
+            content = f.read()
+
+        # Return as binary response with appropriate content type
+        return Response(
+            content=content,
+            media_type="application/octet-stream",
+            headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
+        )
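Because the provider mirrors the OpenAI Files endpoints, the usual upload/list/retrieve/delete flow should apply once it is wired into a distribution. A hedged client-side sketch using the official OpenAI Python client; the base URL, API key, route prefix, and accepted `purpose` values are assumptions that depend on how the server is configured and on the `OpenAIFilePurpose` enum:

```
# Illustrative only: base_url, api_key, file name, and purpose are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

uploaded = client.files.create(file=open("notes.txt", "rb"), purpose="assistants")
print(uploaded.id, uploaded.bytes)

listed = client.files.list()
content = client.files.content(uploaded.id).read()
client.files.delete(uploaded.id)
```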
@@ -30,7 +30,7 @@ class TelemetryConfig(BaseModel):
     )
     service_name: str = Field(
         # service name is always the same, use zero-width space to avoid clutter
-        default="",
+        default="\u200b",
         description="The service name to use for telemetry",
     )
     sinks: list[TelemetrySink] = Field(
@@ -52,7 +52,7 @@ class TelemetryConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
         return {
-            "service_name": "${env.OTEL_SERVICE_NAME:}",
+            "service_name": "${env.OTEL_SERVICE_NAME:\u200b}",
             "sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
             "sqlite_db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
         }
@@ -4,8 +4,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.datatypes import ProviderSpec
+from llama_stack.providers.datatypes import (
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+)
+from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
 
 
 def available_providers() -> list[ProviderSpec]:
-    return []
+    return [
+        InlineProviderSpec(
+            api=Api.files,
+            provider_type="inline::localfs",
+            # TODO: make this dynamic according to the sql store type
+            pip_packages=sql_store_pip_packages,
+            module="llama_stack.providers.inline.files.localfs",
+            config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
+        ),
+    ]
@@ -15,7 +15,6 @@ from llama_stack.providers.datatypes import (
 
 META_REFERENCE_DEPS = [
     "accelerate",
-    "blobfile",
     "fairscale",
     "torch",
     "torchvision",
@@ -20,7 +20,6 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.tool_runtime,
             provider_type="inline::rag-runtime",
             pip_packages=[
-                "blobfile",
                 "chardet",
                 "pypdf",
                 "tqdm",
@@ -255,7 +255,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         params = {
             "model": request.model,
             **input_dict,
-            "stream": request.stream,
+            "stream": bool(request.stream),
             **self._build_options(request.sampling_params, request.response_format, request.logprobs),
         }
         logger.debug(f"params to fireworks: {params}")
@@ -12,7 +12,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_entry,
 )
 
-model_entries = [
+MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1:8b-instruct-fp16",
         CoreModelId.llama3_1_8b_instruct.value,
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 
+import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
@@ -77,7 +78,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     request_has_media,
 )
 
-from .models import model_entries
+from .models import MODEL_ENTRIES
 
 logger = get_logger(name=__name__, category="inference")
 
@@ -87,7 +88,7 @@ class OllamaInferenceAdapter(
     ModelsProtocolPrivate,
 ):
     def __init__(self, url: str) -> None:
-        self.register_helper = ModelRegistryHelper(model_entries)
+        self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
         self.url = url
 
     @property
@@ -480,7 +481,25 @@ class OllamaInferenceAdapter(
             top_p=top_p,
             user=user,
         )
-        return await self.openai_client.chat.completions.create(**params)  # type: ignore
+        response = await self.openai_client.chat.completions.create(**params)
+        return await self._adjust_ollama_chat_completion_response_ids(response)
+
+    async def _adjust_ollama_chat_completion_response_ids(
+        self,
+        response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        id = f"chatcmpl-{uuid.uuid4()}"
+        if isinstance(response, AsyncIterator):
+
+            async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
+                async for chunk in response:
+                    chunk.id = id
+                    yield chunk
+
+            return stream_with_chunk_ids()
+        else:
+            response.id = id
+            return response
 
     async def batch_completion(
         self,
@@ -72,15 +72,15 @@ class PostgresKVStoreConfig(CommonConfig):
     table_name: str = "llamastack_kvstore"
 
     @classmethod
-    def sample_run_config(cls, table_name: str = "llamastack_kvstore"):
+    def sample_run_config(cls, table_name: str = "llamastack_kvstore", **kwargs):
         return {
             "type": "postgres",
             "namespace": None,
             "host": "${env.POSTGRES_HOST:localhost}",
             "port": "${env.POSTGRES_PORT:5432}",
-            "db": "${env.POSTGRES_DB}",
+            "db": "${env.POSTGRES_DB:llamastack}",
-            "user": "${env.POSTGRES_USER}",
+            "user": "${env.POSTGRES_USER:llamastack}",
-            "password": "${env.POSTGRES_PASSWORD}",
+            "password": "${env.POSTGRES_PASSWORD:llamastack}",
             "table_name": "${env.POSTGRES_TABLE_NAME:" + table_name + "}",
         }
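These sample configs lean on the `${env.VAR:default}` placeholder convention used throughout the run configs: the value comes from the named environment variable and falls back to the text after the colon. A small standalone illustration of that resolution rule, written as a toy helper rather than the stack's actual substitution code:

```
# Standalone illustration of the ${env.VAR:default} convention; not the stack's own resolver.
import os
import re

_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+):?([^}]*)\}")


def resolve_env_placeholders(value: str) -> str:
    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        return os.environ.get(name, default)

    return _PATTERN.sub(_sub, value)


assert resolve_env_placeholders("${env.POSTGRES_DB:llamastack}") == os.environ.get("POSTGRES_DB", "llamastack")
```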
@@ -16,6 +16,8 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
 
 from .api import SqlStore
 
+sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
+
 
 class SqlStoreType(Enum):
     sqlite = "sqlite"
@@ -72,6 +74,17 @@ class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
     def pip_packages(self) -> list[str]:
         return super().pip_packages + ["asyncpg"]
 
+    @classmethod
+    def sample_run_config(cls, **kwargs):
+        return cls(
+            type="postgres",
+            host="${env.POSTGRES_HOST:localhost}",
+            port="${env.POSTGRES_PORT:5432}",
+            db="${env.POSTGRES_DB:llamastack}",
+            user="${env.POSTGRES_USER:llamastack}",
+            password="${env.POSTGRES_PASSWORD:llamastack}",
+        )
+
 
 SqlStoreConfig = Annotated[
     SqliteSqlStoreConfig | PostgresSqlStoreConfig,
@@ -42,7 +42,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/trace_store.db
   eval:

@@ -82,7 +82,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/trace_store.db
   tool_runtime:

@@ -45,7 +45,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/trace_store.db
   eval:

@@ -48,7 +48,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/trace_store.db
   eval:

@@ -44,7 +44,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/trace_store.db
   eval:
@@ -24,6 +24,8 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
     - inline::braintrust
+    files:
+    - inline::localfs
     tool_runtime:
     - remote::brave-search
     - remote::tavily-search
@@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
+from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
@@ -36,6 +37,7 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "files": ["inline::localfs"],
         "tool_runtime": [
             "remote::brave-search",
             "remote::tavily-search",
@@ -62,6 +64,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="inline::faiss",
         config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
     )
+    files_provider = Provider(
+        provider_id="meta-reference-files",
+        provider_type="inline::localfs",
+        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+    )
 
     available_models = {
         "fireworks": MODEL_ENTRIES,
@@ -104,6 +111,7 @@ def get_distribution_template() -> DistributionTemplate:
                 provider_overrides={
                     "inference": [inference_provider, embedding_provider],
                     "vector_io": [vector_io_provider],
+                    "files": [files_provider],
                 },
                 default_models=default_models + [embedding_model],
                 default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
@@ -116,6 +124,7 @@ def get_distribution_template() -> DistributionTemplate:
                         embedding_provider,
                     ],
                     "vector_io": [vector_io_provider],
+                    "files": [files_provider],
                     "safety": [
                         Provider(
                             provider_id="llama-guard",
@@ -4,6 +4,7 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - safety
 - scoring
@@ -53,7 +54,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/trace_store.db
   eval:
@@ -90,6 +91,14 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:~/.llama/distributions/fireworks/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/files_metadata.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
@@ -4,6 +4,7 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - safety
 - scoring
@@ -48,7 +49,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/trace_store.db
   eval:
@@ -85,6 +86,14 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:~/.llama/distributions/fireworks/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/files_metadata.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
@@ -48,7 +48,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/trace_store.db
   eval:
@@ -112,7 +112,7 @@ models:
   provider_model_id: groq/llama3-8b-8192
   model_type: llm
 - metadata: {}
-  model_id: meta-llama/Llama-3.1-8B-Instruct
+  model_id: groq/meta-llama/Llama-3.1-8B-Instruct
   provider_id: groq
   provider_model_id: groq/llama3-8b-8192
   model_type: llm
@@ -127,7 +127,7 @@ models:
   provider_model_id: groq/llama3-70b-8192
   model_type: llm
 - metadata: {}
-  model_id: meta-llama/Llama-3-70B-Instruct
+  model_id: groq/meta-llama/Llama-3-70B-Instruct
   provider_id: groq
   provider_model_id: groq/llama3-70b-8192
   model_type: llm
@@ -137,7 +137,7 @@ models:
   provider_model_id: groq/llama-3.3-70b-versatile
   model_type: llm
 - metadata: {}
-  model_id: meta-llama/Llama-3.3-70B-Instruct
+  model_id: groq/meta-llama/Llama-3.3-70B-Instruct
   provider_id: groq
   provider_model_id: groq/llama-3.3-70b-versatile
   model_type: llm
@@ -147,7 +147,7 @@ models:
   provider_model_id: groq/llama-3.2-3b-preview
   model_type: llm
 - metadata: {}
-  model_id: meta-llama/Llama-3.2-3B-Instruct
+  model_id: groq/meta-llama/Llama-3.2-3B-Instruct
   provider_id: groq
   provider_model_id: groq/llama-3.2-3b-preview
   model_type: llm
@@ -157,7 +157,7 @@ models:
   provider_model_id: groq/llama-4-scout-17b-16e-instruct
   model_type: llm
 - metadata: {}
-  model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
+  model_id: groq/meta-llama/Llama-4-Scout-17B-16E-Instruct
   provider_id: groq
   provider_model_id: groq/llama-4-scout-17b-16e-instruct
   model_type: llm
@@ -167,7 +167,7 @@ models:
   provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
   model_type: llm
 - metadata: {}
-  model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
+  model_id: groq/meta-llama/Llama-4-Scout-17B-16E-Instruct
   provider_id: groq
   provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
   model_type: llm
@@ -177,7 +177,7 @@ models:
   provider_model_id: groq/llama-4-maverick-17b-128e-instruct
   model_type: llm
 - metadata: {}
-  model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
+  model_id: groq/meta-llama/Llama-4-Maverick-17B-128E-Instruct
   provider_id: groq
   provider_model_id: groq/llama-4-maverick-17b-128e-instruct
   model_type: llm
@@ -187,7 +187,7 @@ models:
   provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
   model_type: llm
 - metadata: {}
-  model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
+  model_id: groq/meta-llama/Llama-4-Maverick-17B-128E-Instruct
   provider_id: groq
   provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
   model_type: llm
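The model hunks above rename the alias entries from bare meta-llama/... IDs to provider-prefixed groq/meta-llama/... IDs while keeping the same provider_model_id. A representative picture of the resulting mapping, built only from the hunks shown here (a sketch, not code from the repository):

# model_id (as registered in the groq run config)  ->  provider_model_id sent to Groq
GROQ_MODEL_ALIASES = {
    "groq/meta-llama/Llama-3.1-8B-Instruct": "groq/llama3-8b-8192",
    "groq/meta-llama/Llama-3-70B-Instruct": "groq/llama3-70b-8192",
    "groq/meta-llama/Llama-3.3-70B-Instruct": "groq/llama-3.3-70b-versatile",
    "groq/meta-llama/Llama-3.2-3B-Instruct": "groq/llama-3.2-3b-preview",
    "groq/meta-llama/Llama-4-Scout-17B-16E-Instruct": "groq/llama-4-scout-17b-16e-instruct",
    "groq/meta-llama/Llama-4-Maverick-17B-128E-Instruct": "groq/llama-4-maverick-17b-128e-instruct",
}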
@@ -53,7 +53,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/trace_store.db
   eval:

@@ -48,7 +48,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/trace_store.db
   eval:

@@ -53,7 +53,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/trace_store.db
   eval:

@@ -48,7 +48,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/trace_store.db
   eval:

@@ -57,7 +57,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/trace_store.db
   eval:

@@ -63,7 +63,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/trace_store.db
   eval:

@@ -53,7 +53,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/trace_store.db
   eval:

@@ -53,7 +53,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/trace_store.db
   eval:

@@ -48,7 +48,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/trace_store.db
   eval:

@@ -47,7 +47,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/trace_store.db
   eval:

@@ -45,7 +45,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/trace_store.db
   eval:

@@ -71,7 +71,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/trace_store.db
   eval:

@@ -53,7 +53,7 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
+      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/trace_store.db
   eval: