Merge remote-tracking branch 'origin/main' into storage_fix

Author: Ashwin Bharambe
Date: 2025-11-12 10:17:56 -08:00
Commit: 08024d44f2
89 changed files with 4786 additions and 3941 deletions


@ -463,6 +463,12 @@ resources:
settings:
license: MIT
unwrap_response_fields: [data]
file_header: |
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the terms described in the LICENSE file in
the root directory of this source tree.
openapi:
transformations:


@ -963,7 +963,7 @@ paths:
Optional filter to control which routes are returned. Can be an API level
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
or 'deprecated' to show deprecated routes across all levels. If not specified,
returns only non-deprecated v1 routes.
returns all non-deprecated routes.
required: false
schema:
type: string
@ -998,39 +998,6 @@ paths:
description: List models using the OpenAI API.
parameters: []
deprecated: false
post:
responses:
'200':
description: A Model.
content:
application/json:
schema:
$ref: '#/components/schemas/Model'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Register model.
description: >-
Register model.
Register a model.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterModelRequest'
required: true
deprecated: false
/v1/models/{model_id}:
get:
responses:
@ -1065,36 +1032,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Unregister model.
description: >-
Unregister model.
Unregister a model.
parameters:
- name: model_id
in: path
description: >-
The identifier of the model to unregister.
required: true
schema:
type: string
deprecated: false
/v1/moderations:
post:
responses:
@ -1725,32 +1662,6 @@ paths:
description: List all scoring functions.
parameters: []
deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Register a scoring function.
description: Register a scoring function.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterScoringFunctionRequest'
required: true
deprecated: false
/v1/scoring-functions/{scoring_fn_id}:
get:
responses:
@ -1782,33 +1693,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Unregister a scoring function.
description: Unregister a scoring function.
parameters:
- name: scoring_fn_id
in: path
description: >-
The ID of the scoring function to unregister.
required: true
schema:
type: string
deprecated: false
/v1/scoring/score:
post:
responses:
@ -1897,36 +1781,6 @@ paths:
description: List all shields.
parameters: []
deprecated: false
post:
responses:
'200':
description: A Shield.
content:
application/json:
schema:
$ref: '#/components/schemas/Shield'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Register a shield.
description: Register a shield.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterShieldRequest'
required: true
deprecated: false
/v1/shields/{identifier}:
get:
responses:
@ -1958,33 +1812,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Unregister a shield.
description: Unregister a shield.
parameters:
- name: identifier
in: path
description: >-
The identifier of the shield to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tool-runtime/invoke:
post:
responses:
@ -2080,32 +1907,6 @@ paths:
description: List tool groups with optional provider.
parameters: []
deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Register a tool group.
description: Register a tool group.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterToolGroupRequest'
required: true
deprecated: false
/v1/toolgroups/{toolgroup_id}:
get:
responses:
@ -2137,32 +1938,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Unregister a tool group.
description: Unregister a tool group.
parameters:
- name: toolgroup_id
in: path
description: The ID of the tool group to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tools:
get:
responses:
@ -2916,11 +2691,12 @@ paths:
responses:
'200':
description: >-
A list of InterleavedContent representing the file contents.
File contents, optionally with embeddings and metadata based on query
parameters.
content:
application/json:
schema:
$ref: '#/components/schemas/VectorStoreFileContentsResponse'
$ref: '#/components/schemas/VectorStoreFileContentResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -2951,6 +2727,20 @@ paths:
required: true
schema:
type: string
- name: include_embeddings
in: query
description: >-
Whether to include embedding vectors in the response.
required: false
schema:
$ref: '#/components/schemas/bool'
- name: include_metadata
in: query
description: >-
Whether to include chunk metadata in the response.
required: false
schema:
$ref: '#/components/schemas/bool'
deprecated: false
/v1/vector_stores/{vector_store_id}/search:
post:
@ -3171,7 +2961,7 @@ paths:
schema:
$ref: '#/components/schemas/RegisterDatasetRequest'
required: true
deprecated: false
deprecated: true
/v1beta/datasets/{dataset_id}:
get:
responses:
@ -3228,7 +3018,7 @@ paths:
required: true
schema:
type: string
deprecated: false
deprecated: true
/v1alpha/eval/benchmarks:
get:
responses:
@ -3279,7 +3069,7 @@ paths:
schema:
$ref: '#/components/schemas/RegisterBenchmarkRequest'
required: true
deprecated: false
deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}:
get:
responses:
@ -3336,7 +3126,7 @@ paths:
required: true
schema:
type: string
deprecated: false
deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}/evaluations:
post:
responses:
@ -6280,46 +6070,6 @@ components:
required:
- data
title: OpenAIListModelsResponse
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RegisterModelRequest:
type: object
properties:
model_id:
type: string
description: The identifier of the model to register.
provider_model_id:
type: string
description: >-
The identifier of the model in the provider.
provider_id:
type: string
description: The identifier of the provider.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Any additional metadata for this model.
model_type:
$ref: '#/components/schemas/ModelType'
description: The type of model to register.
additionalProperties: false
required:
- model_id
title: RegisterModelRequest
Model:
type: object
properties:
@ -6377,6 +6127,15 @@ components:
title: Model
description: >-
A model resource representing an AI model registered in Llama Stack.
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RunModerationRequest:
type: object
properties:
@ -6882,6 +6641,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
input:
type: array
items:
@ -7240,6 +7004,11 @@ components:
(Optional) Additional fields to include in the response.
max_infer_iters:
type: integer
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response.
additionalProperties: false
required:
- input
@ -7321,6 +7090,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
additionalProperties: false
required:
- created_at
@ -9115,61 +8889,6 @@ components:
required:
- data
title: ListScoringFunctionsResponse
ParamType:
oneOf:
- $ref: '#/components/schemas/StringType'
- $ref: '#/components/schemas/NumberType'
- $ref: '#/components/schemas/BooleanType'
- $ref: '#/components/schemas/ArrayType'
- $ref: '#/components/schemas/ObjectType'
- $ref: '#/components/schemas/JsonType'
- $ref: '#/components/schemas/UnionType'
- $ref: '#/components/schemas/ChatCompletionInputType'
- $ref: '#/components/schemas/CompletionInputType'
discriminator:
propertyName: type
mapping:
string: '#/components/schemas/StringType'
number: '#/components/schemas/NumberType'
boolean: '#/components/schemas/BooleanType'
array: '#/components/schemas/ArrayType'
object: '#/components/schemas/ObjectType'
json: '#/components/schemas/JsonType'
union: '#/components/schemas/UnionType'
chat_completion_input: '#/components/schemas/ChatCompletionInputType'
completion_input: '#/components/schemas/CompletionInputType'
RegisterScoringFunctionRequest:
type: object
properties:
scoring_fn_id:
type: string
description: >-
The ID of the scoring function to register.
description:
type: string
description: The description of the scoring function.
return_type:
$ref: '#/components/schemas/ParamType'
description: The return type of the scoring function.
provider_scoring_fn_id:
type: string
description: >-
The ID of the provider scoring function to use for the scoring function.
provider_id:
type: string
description: >-
The ID of the provider to use for the scoring function.
params:
$ref: '#/components/schemas/ScoringFnParams'
description: >-
The parameters for the scoring function for benchmark eval, these can
be overridden for app eval.
additionalProperties: false
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequest
ScoreRequest:
type: object
properties:
@ -9345,35 +9064,6 @@ components:
required:
- data
title: ListShieldsResponse
RegisterShieldRequest:
type: object
properties:
shield_id:
type: string
description: >-
The identifier of the shield to register.
provider_shield_id:
type: string
description: >-
The identifier of the shield in the provider.
provider_id:
type: string
description: The identifier of the provider.
params:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The parameters of the shield.
additionalProperties: false
required:
- shield_id
title: RegisterShieldRequest
InvokeToolRequest:
type: object
properties:
@ -9634,37 +9324,6 @@ components:
title: ListToolGroupsResponse
description: >-
Response containing a list of tool groups.
RegisterToolGroupRequest:
type: object
properties:
toolgroup_id:
type: string
description: The ID of the tool group to register.
provider_id:
type: string
description: >-
The ID of the provider to use for the tool group.
mcp_endpoint:
$ref: '#/components/schemas/URL'
description: >-
The MCP endpoint to use for the tool group.
args:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
A dictionary of arguments to pass to the tool group.
additionalProperties: false
required:
- toolgroup_id
- provider_id
title: RegisterToolGroupRequest
Chunk:
type: object
properties:
@ -10447,6 +10106,8 @@ components:
title: VectorStoreFileDeleteResponse
description: >-
Response from deleting a vector store file.
bool:
type: boolean
VectorStoreContent:
type: object
properties:
@ -10458,23 +10119,16 @@ components:
text:
type: string
description: The actual text content
additionalProperties: false
required:
- type
- text
title: VectorStoreContent
embedding:
type: array
items:
type: number
description: >-
Content item from a vector store file or search result.
VectorStoreFileContentsResponse:
type: object
properties:
file_id:
type: string
description: Unique identifier for the file
filename:
type: string
description: Name of the file
attributes:
Optional embedding vector for this content chunk
chunk_metadata:
$ref: '#/components/schemas/ChunkMetadata'
description: Optional chunk metadata
metadata:
type: object
additionalProperties:
oneOf:
@ -10484,22 +10138,44 @@ components:
- type: string
- type: array
- type: object
description: Optional user-defined metadata
additionalProperties: false
required:
- type
- text
title: VectorStoreContent
description: >-
Key-value attributes associated with the file
content:
Content item from a vector store file or search result.
VectorStoreFileContentResponse:
type: object
properties:
object:
type: string
const: vector_store.file_content.page
default: vector_store.file_content.page
description: >-
The object type, which is always `vector_store.file_content.page`
data:
type: array
items:
$ref: '#/components/schemas/VectorStoreContent'
description: List of content items from the file
description: Parsed content of the file
has_more:
type: boolean
default: false
description: >-
Indicates if there are more content pages to fetch
next_page:
type: string
description: The token for the next page, if any
additionalProperties: false
required:
- file_id
- filename
- attributes
- content
title: VectorStoreFileContentsResponse
- object
- data
- has_more
title: VectorStoreFileContentResponse
description: >-
Response from retrieving the contents of a vector store file.
Represents the parsed content of a vector store file.
OpenaiSearchVectorStoreRequest:
type: object
properties:
@ -10816,68 +10492,6 @@ components:
- data
title: ListDatasetsResponse
description: Response from listing datasets.
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
Benchmark:
type: object
properties:
@ -10945,47 +10559,6 @@ components:
required:
- data
title: ListBenchmarksResponse
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
BenchmarkConfig:
type: object
properties:
@ -11847,6 +11420,109 @@ components:
- hyperparam_search_config
- logger_config
title: SupervisedFineTuneRequest
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
responses:
BadRequest400:
description: The request was invalid or malformed
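To make the reshaped file-contents schema above concrete, here is an illustrative `VectorStoreFileContentResponse` payload written in YAML. All values, including the `type: text` discriminator, are placeholders; `embedding` and chunk metadata appear only when the corresponding `include_*` query parameters are set.

```yaml
object: vector_store.file_content.page
data:
  - type: text
    text: "First parsed chunk of the file..."
    embedding: [0.12, -0.03, 0.27]   # only present with include_embeddings=true
    metadata:
      source: example.pdf            # hypothetical user-defined metadata
has_more: false
next_page: null
```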


@ -221,7 +221,15 @@ models:
```
A Model is an instance of a "Resource" (see [Concepts](../concepts/)) and is associated with a specific inference provider (in this case, the provider with identifier `ollama`). This is an instance of a "pre-registered" model. While we always encourage clients to register models before using them, some Stack servers may come up with a list of "already known and available" models.
What's with the `provider_model_id` field? This is an identifier for the model inside the provider's model catalog. Contrast it with `model_id` which is the identifier for the same model for Llama Stack's purposes. For example, you may want to name "llama3.2:vision-11b" as "image_captioning_model" when you use it in your Stack interactions. When omitted, the server will set `provider_model_id` to be the same as `model_id`.
What's with the `provider_model_id` field? This is an identifier for the model inside the provider's model catalog. The `model_id` field is provided for configuration purposes but is not used as part of the model identifier.
**Important:** Models are identified as `provider_id/provider_model_id` in the system and when making API calls. When `provider_model_id` is omitted, the server will set it to be the same as `model_id`.
Examples:
- Config: `model_id: llama3.2`, `provider_id: ollama`, `provider_model_id: null`
→ Access as: `ollama/llama3.2`
- Config: `model_id: my-llama`, `provider_id: vllm-inference`, `provider_model_id: llama-3-2-3b`
→ Access as: `vllm-inference/llama-3-2-3b` (the `model_id` is not used in the identifier)
If you need to conditionally register a model in the configuration, such as only when specific environment variable(s) are set, this can be accomplished by utilizing a special `__disabled__` string as the default value of an environment variable substitution, as shown below:
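The referenced example sits outside this hunk; as a rough sketch (the model and variable names here are placeholders, and the `${env.VAR:=default}` substitution syntax mirrors the sample configs elsewhere in this change), a conditional registration in a run config can look like:

```yaml
models:
  # Registered only when INFERENCE_MODEL is set; otherwise the value resolves to
  # the special __disabled__ marker and the entry is skipped by the server.
  - model_id: ${env.INFERENCE_MODEL:=__disabled__}
    provider_id: ollama
    provider_model_id: ${env.INFERENCE_MODEL:=__disabled__}
```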


@ -19,3 +19,4 @@ This section provides an overview of the distributions available in Llama Stack.
- **[Starting Llama Stack Server](./starting_llama_stack_server.mdx)** - How to run distributions
- **[Importing as Library](./importing_as_library.mdx)** - Use distributions in your code
- **[Configuration Reference](./configuration.mdx)** - Configuration file format details
- **[Llama Stack UI](./llama_stack_ui.mdx)** - Web-based user interface for interacting with Llama Stack servers


@ -0,0 +1,109 @@
---
title: Llama Stack UI
description: Web-based user interface for interacting with Llama Stack servers
sidebar_label: Llama Stack UI
sidebar_position: 8
---
# Llama Stack UI
The Llama Stack UI is a web-based interface for interacting with Llama Stack servers. Built with Next.js and React, it provides a visual way to work with agents, manage resources, and view logs.
## Features
- **Logs & Monitoring**: View chat completions, agent responses, and vector store activity
- **Vector Stores**: Create and manage vector databases for RAG (Retrieval-Augmented Generation) workflows
- **Prompt Management**: Create and manage reusable prompts
## Prerequisites
You need a running Llama Stack server. The UI is a client that connects to the Llama Stack backend.
If you don't have a Llama Stack server running yet, see the [Starting Llama Stack Server](../getting_started/starting_llama_stack_server.mdx) guide.
## Running the UI
### Option 1: Using npx (Recommended for Quick Start)
The fastest way to get started is using `npx`:
```bash
npx llama-stack-ui
```
This will start the UI server on `http://localhost:8322` (default port).
### Option 2: Using Docker
Run the UI in a container:
```bash
docker run -p 8322:8322 llamastack/ui
```
Access the UI at `http://localhost:8322`.
## Environment Variables
The UI can be configured using the following environment variables:
| Variable | Description | Default |
|----------|-------------|---------|
| `LLAMA_STACK_BACKEND_URL` | URL of your Llama Stack server | `http://localhost:8321` |
| `LLAMA_STACK_UI_PORT` | Port for the UI server | `8322` |
If the Llama Stack server is running with authentication enabled, you can configure the UI to use it by setting the following environment variables:
| Variable | Description | Default |
|----------|-------------|---------|
| `NEXTAUTH_URL` | NextAuth URL for authentication | `http://localhost:8322` |
| `GITHUB_CLIENT_ID` | GitHub OAuth client ID (optional, for authentication) | - |
| `GITHUB_CLIENT_SECRET` | GitHub OAuth client secret (optional, for authentication) | - |
### Setting Environment Variables
#### For npx:
```bash
LLAMA_STACK_BACKEND_URL=http://localhost:8321 \
LLAMA_STACK_UI_PORT=8080 \
npx llama-stack-ui
```
#### For Docker:
```bash
docker run -p 8080:8080 \
-e LLAMA_STACK_BACKEND_URL=http://localhost:8321 \
-e LLAMA_STACK_UI_PORT=8080 \
llamastack/ui
```
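For multi-container setups, the same image and variables can be captured in a compose file; a hedged sketch (the service name is arbitrary, and from inside a container the backend is typically reachable via `host.docker.internal` rather than `localhost`):

```yaml
# docker-compose.yml -- illustrative only
services:
  llama-stack-ui:
    image: llamastack/ui
    ports:
      - "8322:8322"
    environment:
      LLAMA_STACK_BACKEND_URL: http://host.docker.internal:8321
      LLAMA_STACK_UI_PORT: "8322"
```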
## Using the UI
### Managing Resources
- **Vector Stores**: Create vector databases for RAG workflows, view stored documents and embeddings
- **Prompts**: Create and manage reusable prompt templates
- **Chat Completions**: View history of chat interactions
- **Responses**: Browse detailed agent responses and tool calls
## Development
If you want to run the UI from source for development:
```bash
# From the project root
cd src/llama_stack_ui
# Install dependencies
npm install
# Set environment variables
export LLAMA_STACK_BACKEND_URL=http://localhost:8321
# Start the development server
npm run dev
```
The development server will start on `http://localhost:8322` with hot reloading enabled.


@ -0,0 +1,143 @@
---
orphan: true
---
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# OCI Distribution
The `llamastack/distribution-oci` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| files | `inline::localfs` |
| inference | `remote::oci` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
### Environment Variables
The following environment variables can be configured:
- `OCI_AUTH_TYPE`: OCI authentication type (instance_principal or config_file) (default: `instance_principal`)
- `OCI_REGION`: OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1) (default: ``)
- `OCI_COMPARTMENT_OCID`: OCI compartment ID for the Generative AI service (default: ``)
- `OCI_CONFIG_FILE_PATH`: OCI config file path (required if OCI_AUTH_TYPE is config_file) (default: `~/.oci/config`)
- `OCI_CLI_PROFILE`: OCI CLI profile name to use from config file (default: `DEFAULT`)
## Prerequisites
### Oracle Cloud Infrastructure Setup
Before using the OCI Generative AI distribution, ensure you have:
1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/)
2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy
3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models
4. **Authentication**: Configure authentication using either:
- **Instance Principal** (recommended for cloud-hosted deployments)
- **API Key** (for on-premises or development environments)
### Authentication Methods
#### Instance Principal Authentication (Recommended)
Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments.
Requirements:
- Instance must be running in an Oracle Cloud Infrastructure compartment
- Instance must have appropriate IAM policies to access Generative AI services
#### API Key Authentication
For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file.
### Required IAM Policies
Ensure your OCI user or instance has the following policy statements:
```
Allow group <group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
Allow group <group_name> to manage generative-ai-inference-endpoints in compartment <compartment_name>
```
## Supported Services
### Inference: OCI Generative AI
Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports:
- **Chat Completions**: Conversational AI with context awareness
- **Text Generation**: Complete prompts and generate text content
#### Available Models
OCI Generative AI provides access to models from Meta, Cohere, OpenAI, Grok, and more.
### Safety: Llama Guard
For content safety and moderation, this distribution uses Meta's LlamaGuard model through the OCI Generative AI service to provide:
- Content filtering and moderation
- Policy compliance checking
- Harmful content detection
### Vector Storage: Multiple Options
The distribution supports several vector storage providers:
- **FAISS**: Local in-memory vector search
- **ChromaDB**: Distributed vector database
- **PGVector**: PostgreSQL with vector extensions
### Additional Services
- **Dataset I/O**: Local filesystem and Hugging Face integration
- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities
- **Evaluation**: Meta reference evaluation framework
## Running Llama Stack with OCI
You can run the OCI distribution via Docker or a local virtual environment.
### Via venv
If you've set up your local development environment, you can also run the distribution using your local virtual environment.
```bash
OCI_AUTH_TYPE=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci
```
### Configuration Examples
#### Using Instance Principal (Recommended for Production)
```bash
export OCI_AUTH_TYPE=instance_principal
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..<your-compartment-id>
```
#### Using API Key Authentication (Development)
```bash
export OCI_AUTH_TYPE=config_file
export OCI_CONFIG_FILE_PATH=~/.oci/config
export OCI_CLI_PROFILE=DEFAULT
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
```
## Regional Endpoints
OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit:
https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions
## Troubleshooting
### Common Issues
1. **Authentication Errors**: Verify your OCI credentials and IAM policies
2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region
3. **Permission Denied**: Check compartment permissions and Generative AI service access
4. **Region Unavailable**: Verify the specified region supports Generative AI services
### Getting Help
For additional support:
- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)
- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues)


@ -144,7 +144,7 @@ source .venv/bin/activate
```bash
uv venv client --python 3.12
source client/bin/activate
pip install llama-stack-client
uv pip install llama-stack-client
```
</TabItem>
</Tabs>


@ -0,0 +1,41 @@
---
description: |
Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models.
Provider documentation
https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
sidebar_label: Remote - Oci
title: remote::oci
---
# remote::oci
## Description
Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models.
Provider documentation
https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
| `oci_auth_type` | `<class 'str'>` | No | instance_principal | OCI authentication type (must be one of: instance_principal, config_file) |
| `oci_region` | `<class 'str'>` | No | us-ashburn-1 | OCI region (e.g., us-ashburn-1) |
| `oci_compartment_id` | `<class 'str'>` | No | | OCI compartment ID for the Generative AI service |
| `oci_config_file_path` | `<class 'str'>` | No | ~/.oci/config | OCI config file path (required if oci_auth_type is config_file) |
| `oci_config_profile` | `<class 'str'>` | No | DEFAULT | OCI config profile (required if oci_auth_type is config_file) |
## Sample Configuration
```yaml
oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
oci_region: ${env.OCI_REGION:=us-ashburn-1}
oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
```
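If you also want to control the model-registry behaviour described in the configuration table, the optional `refresh_models` and `allowed_models` fields can be set alongside the sample values; a hedged sketch (the model identifier is a placeholder):

```yaml
oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
oci_region: ${env.OCI_REGION:=us-ashburn-1}
oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
refresh_models: false
allowed_models:
  - meta.llama-3.1-70b-instruct  # placeholder model ID
```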


@ -57,6 +57,7 @@ const sidebars: SidebarsConfig = {
'distributions/importing_as_library',
'distributions/configuration',
'distributions/starting_llama_stack_server',
'distributions/llama_stack_ui',
{
type: 'category',
label: 'Self-Hosted Distributions',

File diff suppressed because it is too large.


@ -162,7 +162,7 @@ paths:
schema:
$ref: '#/components/schemas/RegisterDatasetRequest'
required: true
deprecated: false
deprecated: true
/v1beta/datasets/{dataset_id}:
get:
responses:
@ -219,7 +219,7 @@ paths:
required: true
schema:
type: string
deprecated: false
deprecated: true
/v1alpha/eval/benchmarks:
get:
responses:
@ -270,7 +270,7 @@ paths:
schema:
$ref: '#/components/schemas/RegisterBenchmarkRequest'
required: true
deprecated: false
deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}:
get:
responses:
@ -327,7 +327,7 @@ paths:
required: true
schema:
type: string
deprecated: false
deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}/evaluations:
post:
responses:
@ -936,68 +936,6 @@ components:
- data
title: ListDatasetsResponse
description: Response from listing datasets.
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
Benchmark:
type: object
properties:
@ -1065,47 +1003,6 @@ components:
required:
- data
title: ListBenchmarksResponse
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
AggregationFunctionType:
type: string
enum:
@ -2254,6 +2151,109 @@ components:
- hyperparam_search_config
- logger_config
title: SupervisedFineTuneRequest
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
responses:
BadRequest400:
description: The request was invalid or malformed


@ -960,7 +960,7 @@ paths:
Optional filter to control which routes are returned. Can be an API level
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
or 'deprecated' to show deprecated routes across all levels. If not specified,
returns only non-deprecated v1 routes.
returns all non-deprecated routes.
required: false
schema:
type: string
@ -995,39 +995,6 @@ paths:
description: List models using the OpenAI API.
parameters: []
deprecated: false
post:
responses:
'200':
description: A Model.
content:
application/json:
schema:
$ref: '#/components/schemas/Model'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Register model.
description: >-
Register model.
Register a model.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterModelRequest'
required: true
deprecated: false
/v1/models/{model_id}:
get:
responses:
@ -1062,36 +1029,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Unregister model.
description: >-
Unregister model.
Unregister a model.
parameters:
- name: model_id
in: path
description: >-
The identifier of the model to unregister.
required: true
schema:
type: string
deprecated: false
/v1/moderations:
post:
responses:
@ -1722,32 +1659,6 @@ paths:
description: List all scoring functions.
parameters: []
deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Register a scoring function.
description: Register a scoring function.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterScoringFunctionRequest'
required: true
deprecated: false
/v1/scoring-functions/{scoring_fn_id}:
get:
responses:
@ -1779,33 +1690,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Unregister a scoring function.
description: Unregister a scoring function.
parameters:
- name: scoring_fn_id
in: path
description: >-
The ID of the scoring function to unregister.
required: true
schema:
type: string
deprecated: false
/v1/scoring/score:
post:
responses:
@ -1894,36 +1778,6 @@ paths:
description: List all shields.
parameters: []
deprecated: false
post:
responses:
'200':
description: A Shield.
content:
application/json:
schema:
$ref: '#/components/schemas/Shield'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Register a shield.
description: Register a shield.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterShieldRequest'
required: true
deprecated: false
/v1/shields/{identifier}:
get:
responses:
@ -1955,33 +1809,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Unregister a shield.
description: Unregister a shield.
parameters:
- name: identifier
in: path
description: >-
The identifier of the shield to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tool-runtime/invoke:
post:
responses:
@ -2077,32 +1904,6 @@ paths:
description: List tool groups with optional provider.
parameters: []
deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Register a tool group.
description: Register a tool group.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterToolGroupRequest'
required: true
deprecated: false
/v1/toolgroups/{toolgroup_id}:
get:
responses:
@ -2134,32 +1935,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Unregister a tool group.
description: Unregister a tool group.
parameters:
- name: toolgroup_id
in: path
description: The ID of the tool group to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tools:
get:
responses:
@ -2913,11 +2688,12 @@ paths:
responses:
'200':
description: >-
A list of InterleavedContent representing the file contents.
File contents, optionally with embeddings and metadata based on query
parameters.
content:
application/json:
schema:
$ref: '#/components/schemas/VectorStoreFileContentsResponse'
$ref: '#/components/schemas/VectorStoreFileContentResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -2948,6 +2724,20 @@ paths:
required: true
schema:
type: string
- name: include_embeddings
in: query
description: >-
Whether to include embedding vectors in the response.
required: false
schema:
$ref: '#/components/schemas/bool'
- name: include_metadata
in: query
description: >-
Whether to include chunk metadata in the response.
required: false
schema:
$ref: '#/components/schemas/bool'
deprecated: false
/v1/vector_stores/{vector_store_id}/search:
post:
@ -5564,46 +5354,6 @@ components:
required:
- data
title: OpenAIListModelsResponse
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RegisterModelRequest:
type: object
properties:
model_id:
type: string
description: The identifier of the model to register.
provider_model_id:
type: string
description: >-
The identifier of the model in the provider.
provider_id:
type: string
description: The identifier of the provider.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Any additional metadata for this model.
model_type:
$ref: '#/components/schemas/ModelType'
description: The type of model to register.
additionalProperties: false
required:
- model_id
title: RegisterModelRequest
Model:
type: object
properties:
@ -5661,6 +5411,15 @@ components:
title: Model
description: >-
A model resource representing an AI model registered in Llama Stack.
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RunModerationRequest:
type: object
properties:
@ -6166,6 +5925,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
input:
type: array
items:
@ -6524,6 +6288,11 @@ components:
(Optional) Additional fields to include in the response.
max_infer_iters:
type: integer
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response.
additionalProperties: false
required:
- input
@ -6605,6 +6374,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
additionalProperties: false
required:
- created_at
@ -8399,61 +8173,6 @@ components:
required:
- data
title: ListScoringFunctionsResponse
ParamType:
oneOf:
- $ref: '#/components/schemas/StringType'
- $ref: '#/components/schemas/NumberType'
- $ref: '#/components/schemas/BooleanType'
- $ref: '#/components/schemas/ArrayType'
- $ref: '#/components/schemas/ObjectType'
- $ref: '#/components/schemas/JsonType'
- $ref: '#/components/schemas/UnionType'
- $ref: '#/components/schemas/ChatCompletionInputType'
- $ref: '#/components/schemas/CompletionInputType'
discriminator:
propertyName: type
mapping:
string: '#/components/schemas/StringType'
number: '#/components/schemas/NumberType'
boolean: '#/components/schemas/BooleanType'
array: '#/components/schemas/ArrayType'
object: '#/components/schemas/ObjectType'
json: '#/components/schemas/JsonType'
union: '#/components/schemas/UnionType'
chat_completion_input: '#/components/schemas/ChatCompletionInputType'
completion_input: '#/components/schemas/CompletionInputType'
RegisterScoringFunctionRequest:
type: object
properties:
scoring_fn_id:
type: string
description: >-
The ID of the scoring function to register.
description:
type: string
description: The description of the scoring function.
return_type:
$ref: '#/components/schemas/ParamType'
description: The return type of the scoring function.
provider_scoring_fn_id:
type: string
description: >-
The ID of the provider scoring function to use for the scoring function.
provider_id:
type: string
description: >-
The ID of the provider to use for the scoring function.
params:
$ref: '#/components/schemas/ScoringFnParams'
description: >-
The parameters for the scoring function for benchmark eval, these can
be overridden for app eval.
additionalProperties: false
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequest
ScoreRequest:
type: object
properties:
@ -8629,35 +8348,6 @@ components:
required:
- data
title: ListShieldsResponse
RegisterShieldRequest:
type: object
properties:
shield_id:
type: string
description: >-
The identifier of the shield to register.
provider_shield_id:
type: string
description: >-
The identifier of the shield in the provider.
provider_id:
type: string
description: The identifier of the provider.
params:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The parameters of the shield.
additionalProperties: false
required:
- shield_id
title: RegisterShieldRequest
InvokeToolRequest:
type: object
properties:
@ -8918,37 +8608,6 @@ components:
title: ListToolGroupsResponse
description: >-
Response containing a list of tool groups.
RegisterToolGroupRequest:
type: object
properties:
toolgroup_id:
type: string
description: The ID of the tool group to register.
provider_id:
type: string
description: >-
The ID of the provider to use for the tool group.
mcp_endpoint:
$ref: '#/components/schemas/URL'
description: >-
The MCP endpoint to use for the tool group.
args:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
A dictionary of arguments to pass to the tool group.
additionalProperties: false
required:
- toolgroup_id
- provider_id
title: RegisterToolGroupRequest
Chunk:
type: object
properties:
@ -9731,6 +9390,8 @@ components:
title: VectorStoreFileDeleteResponse
description: >-
Response from deleting a vector store file.
bool:
type: boolean
VectorStoreContent:
type: object
properties:
@ -9742,23 +9403,16 @@ components:
text:
type: string
description: The actual text content
additionalProperties: false
required:
- type
- text
title: VectorStoreContent
embedding:
type: array
items:
type: number
description: >-
Content item from a vector store file or search result.
VectorStoreFileContentsResponse:
type: object
properties:
file_id:
type: string
description: Unique identifier for the file
filename:
type: string
description: Name of the file
attributes:
Optional embedding vector for this content chunk
chunk_metadata:
$ref: '#/components/schemas/ChunkMetadata'
description: Optional chunk metadata
metadata:
type: object
additionalProperties:
oneOf:
@ -9768,22 +9422,44 @@ components:
- type: string
- type: array
- type: object
description: Optional user-defined metadata
additionalProperties: false
required:
- type
- text
title: VectorStoreContent
description: >-
Key-value attributes associated with the file
content:
Content item from a vector store file or search result.
VectorStoreFileContentResponse:
type: object
properties:
object:
type: string
const: vector_store.file_content.page
default: vector_store.file_content.page
description: >-
The object type, which is always `vector_store.file_content.page`
data:
type: array
items:
$ref: '#/components/schemas/VectorStoreContent'
description: List of content items from the file
description: Parsed content of the file
has_more:
type: boolean
default: false
description: >-
Indicates if there are more content pages to fetch
next_page:
type: string
description: The token for the next page, if any
additionalProperties: false
required:
- file_id
- filename
- attributes
- content
title: VectorStoreFileContentsResponse
- object
- data
- has_more
title: VectorStoreFileContentResponse
description: >-
Response from retrieving the contents of a vector store file.
Represents the parsed content of a vector store file.
OpenaiSearchVectorStoreRequest:
type: object
properties:

View file

@ -963,7 +963,7 @@ paths:
Optional filter to control which routes are returned. Can be an API level
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
or 'deprecated' to show deprecated routes across all levels. If not specified,
returns only non-deprecated v1 routes.
returns all non-deprecated routes.
required: false
schema:
type: string
@ -998,39 +998,6 @@ paths:
description: List models using the OpenAI API.
parameters: []
deprecated: false
post:
responses:
'200':
description: A Model.
content:
application/json:
schema:
$ref: '#/components/schemas/Model'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Register model.
description: >-
Register model.
Register a model.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterModelRequest'
required: true
deprecated: false
/v1/models/{model_id}:
get:
responses:
@ -1065,36 +1032,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Unregister model.
description: >-
Unregister model.
Unregister a model.
parameters:
- name: model_id
in: path
description: >-
The identifier of the model to unregister.
required: true
schema:
type: string
deprecated: false
/v1/moderations:
post:
responses:
@ -1725,32 +1662,6 @@ paths:
description: List all scoring functions.
parameters: []
deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Register a scoring function.
description: Register a scoring function.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterScoringFunctionRequest'
required: true
deprecated: false
/v1/scoring-functions/{scoring_fn_id}:
get:
responses:
@ -1782,33 +1693,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Unregister a scoring function.
description: Unregister a scoring function.
parameters:
- name: scoring_fn_id
in: path
description: >-
The ID of the scoring function to unregister.
required: true
schema:
type: string
deprecated: false
/v1/scoring/score:
post:
responses:
@ -1897,36 +1781,6 @@ paths:
description: List all shields.
parameters: []
deprecated: false
post:
responses:
'200':
description: A Shield.
content:
application/json:
schema:
$ref: '#/components/schemas/Shield'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Register a shield.
description: Register a shield.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterShieldRequest'
required: true
deprecated: false
/v1/shields/{identifier}:
get:
responses:
@ -1958,33 +1812,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Unregister a shield.
description: Unregister a shield.
parameters:
- name: identifier
in: path
description: >-
The identifier of the shield to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tool-runtime/invoke:
post:
responses:
@ -2080,32 +1907,6 @@ paths:
description: List tool groups with optional provider.
parameters: []
deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Register a tool group.
description: Register a tool group.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterToolGroupRequest'
required: true
deprecated: false
/v1/toolgroups/{toolgroup_id}:
get:
responses:
@ -2137,32 +1938,6 @@ paths:
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Unregister a tool group.
description: Unregister a tool group.
parameters:
- name: toolgroup_id
in: path
description: The ID of the tool group to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tools:
get:
responses:
@ -2916,11 +2691,12 @@ paths:
responses:
'200':
description: >-
A list of InterleavedContent representing the file contents.
File contents, optionally with embeddings and metadata based on query
parameters.
content:
application/json:
schema:
$ref: '#/components/schemas/VectorStoreFileContentsResponse'
$ref: '#/components/schemas/VectorStoreFileContentResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -2951,6 +2727,20 @@ paths:
required: true
schema:
type: string
- name: include_embeddings
in: query
description: >-
Whether to include embedding vectors in the response.
required: false
schema:
$ref: '#/components/schemas/bool'
- name: include_metadata
in: query
description: >-
Whether to include chunk metadata in the response.
required: false
schema:
$ref: '#/components/schemas/bool'
deprecated: false
/v1/vector_stores/{vector_store_id}/search:
post:
@ -3171,7 +2961,7 @@ paths:
schema:
$ref: '#/components/schemas/RegisterDatasetRequest'
required: true
deprecated: false
deprecated: true
/v1beta/datasets/{dataset_id}:
get:
responses:
@ -3228,7 +3018,7 @@ paths:
required: true
schema:
type: string
deprecated: false
deprecated: true
/v1alpha/eval/benchmarks:
get:
responses:
@ -3279,7 +3069,7 @@ paths:
schema:
$ref: '#/components/schemas/RegisterBenchmarkRequest'
required: true
deprecated: false
deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}:
get:
responses:
@ -3336,7 +3126,7 @@ paths:
required: true
schema:
type: string
deprecated: false
deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}/evaluations:
post:
responses:
@ -6280,46 +6070,6 @@ components:
required:
- data
title: OpenAIListModelsResponse
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RegisterModelRequest:
type: object
properties:
model_id:
type: string
description: The identifier of the model to register.
provider_model_id:
type: string
description: >-
The identifier of the model in the provider.
provider_id:
type: string
description: The identifier of the provider.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Any additional metadata for this model.
model_type:
$ref: '#/components/schemas/ModelType'
description: The type of model to register.
additionalProperties: false
required:
- model_id
title: RegisterModelRequest
Model:
type: object
properties:
@ -6377,6 +6127,15 @@ components:
title: Model
description: >-
A model resource representing an AI model registered in Llama Stack.
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RunModerationRequest:
type: object
properties:
@ -6882,6 +6641,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
input:
type: array
items:
@ -7240,6 +7004,11 @@ components:
(Optional) Additional fields to include in the response.
max_infer_iters:
type: integer
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response.
additionalProperties: false
required:
- input
@ -7321,6 +7090,11 @@ components:
type: string
description: >-
(Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
additionalProperties: false
required:
- created_at
@ -9115,61 +8889,6 @@ components:
required:
- data
title: ListScoringFunctionsResponse
ParamType:
oneOf:
- $ref: '#/components/schemas/StringType'
- $ref: '#/components/schemas/NumberType'
- $ref: '#/components/schemas/BooleanType'
- $ref: '#/components/schemas/ArrayType'
- $ref: '#/components/schemas/ObjectType'
- $ref: '#/components/schemas/JsonType'
- $ref: '#/components/schemas/UnionType'
- $ref: '#/components/schemas/ChatCompletionInputType'
- $ref: '#/components/schemas/CompletionInputType'
discriminator:
propertyName: type
mapping:
string: '#/components/schemas/StringType'
number: '#/components/schemas/NumberType'
boolean: '#/components/schemas/BooleanType'
array: '#/components/schemas/ArrayType'
object: '#/components/schemas/ObjectType'
json: '#/components/schemas/JsonType'
union: '#/components/schemas/UnionType'
chat_completion_input: '#/components/schemas/ChatCompletionInputType'
completion_input: '#/components/schemas/CompletionInputType'
RegisterScoringFunctionRequest:
type: object
properties:
scoring_fn_id:
type: string
description: >-
The ID of the scoring function to register.
description:
type: string
description: The description of the scoring function.
return_type:
$ref: '#/components/schemas/ParamType'
description: The return type of the scoring function.
provider_scoring_fn_id:
type: string
description: >-
The ID of the provider scoring function to use for the scoring function.
provider_id:
type: string
description: >-
The ID of the provider to use for the scoring function.
params:
$ref: '#/components/schemas/ScoringFnParams'
description: >-
The parameters for the scoring function for benchmark eval, these can
be overridden for app eval.
additionalProperties: false
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequest
ScoreRequest:
type: object
properties:
@ -9345,35 +9064,6 @@ components:
required:
- data
title: ListShieldsResponse
RegisterShieldRequest:
type: object
properties:
shield_id:
type: string
description: >-
The identifier of the shield to register.
provider_shield_id:
type: string
description: >-
The identifier of the shield in the provider.
provider_id:
type: string
description: The identifier of the provider.
params:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The parameters of the shield.
additionalProperties: false
required:
- shield_id
title: RegisterShieldRequest
InvokeToolRequest:
type: object
properties:
@ -9634,37 +9324,6 @@ components:
title: ListToolGroupsResponse
description: >-
Response containing a list of tool groups.
RegisterToolGroupRequest:
type: object
properties:
toolgroup_id:
type: string
description: The ID of the tool group to register.
provider_id:
type: string
description: >-
The ID of the provider to use for the tool group.
mcp_endpoint:
$ref: '#/components/schemas/URL'
description: >-
The MCP endpoint to use for the tool group.
args:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
A dictionary of arguments to pass to the tool group.
additionalProperties: false
required:
- toolgroup_id
- provider_id
title: RegisterToolGroupRequest
Chunk:
type: object
properties:
@ -10447,6 +10106,8 @@ components:
title: VectorStoreFileDeleteResponse
description: >-
Response from deleting a vector store file.
bool:
type: boolean
VectorStoreContent:
type: object
properties:
@ -10458,23 +10119,16 @@ components:
text:
type: string
description: The actual text content
additionalProperties: false
required:
- type
- text
title: VectorStoreContent
embedding:
type: array
items:
type: number
description: >-
Content item from a vector store file or search result.
VectorStoreFileContentsResponse:
type: object
properties:
file_id:
type: string
description: Unique identifier for the file
filename:
type: string
description: Name of the file
attributes:
Optional embedding vector for this content chunk
chunk_metadata:
$ref: '#/components/schemas/ChunkMetadata'
description: Optional chunk metadata
metadata:
type: object
additionalProperties:
oneOf:
@ -10484,22 +10138,44 @@ components:
- type: string
- type: array
- type: object
description: Optional user-defined metadata
additionalProperties: false
required:
- type
- text
title: VectorStoreContent
description: >-
Key-value attributes associated with the file
content:
Content item from a vector store file or search result.
VectorStoreFileContentResponse:
type: object
properties:
object:
type: string
const: vector_store.file_content.page
default: vector_store.file_content.page
description: >-
The object type, which is always `vector_store.file_content.page`
data:
type: array
items:
$ref: '#/components/schemas/VectorStoreContent'
description: List of content items from the file
description: Parsed content of the file
has_more:
type: boolean
default: false
description: >-
Indicates if there are more content pages to fetch
next_page:
type: string
description: The token for the next page, if any
additionalProperties: false
required:
- file_id
- filename
- attributes
- content
title: VectorStoreFileContentsResponse
- object
- data
- has_more
title: VectorStoreFileContentResponse
description: >-
Response from retrieving the contents of a vector store file.
Represents the parsed content of a vector store file.
OpenaiSearchVectorStoreRequest:
type: object
properties:
@ -10816,68 +10492,6 @@ components:
- data
title: ListDatasetsResponse
description: Response from listing datasets.
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
Benchmark:
type: object
properties:
@ -10945,47 +10559,6 @@ components:
required:
- data
title: ListBenchmarksResponse
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
BenchmarkConfig:
type: object
properties:
@ -11847,6 +11420,109 @@ components:
- hyperparam_search_config
- logger_config
title: SupervisedFineTuneRequest
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
responses:
BadRequest400:
description: The request was invalid or malformed

View file

@ -112,7 +112,7 @@ unit = [
"aiosqlite",
"aiohttp",
"psycopg2-binary>=2.9.0",
"pypdf",
"pypdf>=6.1.3",
"mcp",
"chardet",
"sqlalchemy",
@ -135,7 +135,7 @@ test = [
"torchvision>=0.21.0",
"chardet",
"psycopg2-binary>=2.9.0",
"pypdf",
"pypdf>=6.1.3",
"mcp",
"datasets>=4.0.0",
"autoevals",
@ -298,6 +298,7 @@ exclude = [
"^src/llama_stack/providers/remote/agents/sample/",
"^src/llama_stack/providers/remote/datasetio/huggingface/",
"^src/llama_stack/providers/remote/datasetio/nvidia/",
"^src/llama_stack/providers/remote/inference/oci/",
"^src/llama_stack/providers/remote/inference/bedrock/",
"^src/llama_stack/providers/remote/inference/nvidia/",
"^src/llama_stack/providers/remote/inference/passthrough/",

View file

@ -87,6 +87,7 @@ class Agents(Protocol):
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
),
] = None,
max_tool_calls: int | None = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a model response.
@ -97,6 +98,7 @@ class Agents(Protocol):
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response.
:returns: An OpenAIResponseObject.
"""
...
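
A minimal sketch of the new knob, assuming the protocol method is `create_openai_response` (the name is not visible in this hunk) and that an initialized `Agents` implementation and a registered model are available; all identifiers are illustrative:

```python
# Hedged sketch: `agents_api` is assumed to be an initialized Agents implementation;
# the method name and model id are assumptions, only `max_tool_calls` comes from this diff.
async def respond_with_tool_budget(agents_api) -> None:
    response = await agents_api.create_openai_response(
        input="Summarize today's top AI news.",
        model="llama3.2-3b-instruct",   # hypothetical model id
        max_tool_calls=2,               # cap total built-in tool invocations (e.g. web search)
        stream=False,
    )
    # The cap is echoed back on the response object (see the model change below).
    print(response.id, response.max_tool_calls)
```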

View file

@ -594,6 +594,7 @@ class OpenAIResponseObject(BaseModel):
:param truncation: (Optional) Truncation strategy applied to the response
:param usage: (Optional) Token usage information for the response
:param instructions: (Optional) System message inserted into the model's context
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
"""
created_at: int
@ -615,6 +616,7 @@ class OpenAIResponseObject(BaseModel):
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
instructions: str | None = None
max_tool_calls: int | None = None
@json_schema_type

View file

@ -74,7 +74,7 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
async def register_benchmark(
self,
benchmark_id: str,
@ -95,7 +95,7 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark.

View file

@ -146,7 +146,7 @@ class ListDatasetsResponse(BaseModel):
class Datasets(Protocol):
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA, deprecated=True)
async def register_dataset(
self,
purpose: DatasetPurpose,
@ -235,7 +235,7 @@ class Datasets(Protocol):
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA, deprecated=True)
async def unregister_dataset(
self,
dataset_id: str,

View file

@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
)
class LogEvent:
def __init__(
self,
content: str = "",
end: str = "\n",
color="white",
):
self.content = content
self.color = color
self.end = "\n" if end is None else end
def print(self, flush=True):
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
class EventLogger:
async def log(self, event_generator):
async for chunk in event_generator:
if isinstance(chunk, ChatCompletionResponseStreamChunk):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
yield LogEvent("Assistant> ", color="cyan", end="")
elif event.event_type == ChatCompletionResponseEventType.progress:
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == ChatCompletionResponseEventType.complete:
yield LogEvent("")
else:
yield LogEvent("Assistant> ", color="cyan", end="")
yield LogEvent(chunk.completion_message.content, color="yellow")

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from collections.abc import AsyncIterator
from enum import Enum
from enum import Enum, StrEnum
from typing import (
Annotated,
Any,
@ -15,28 +15,18 @@ from typing import (
)
from fastapi import Body
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import MetricResponseMixin, Order
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.responses import (
Order,
)
from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.models import Model
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
register_schema(ToolCall)
register_schema(ToolDefinition)
from enum import StrEnum
@json_schema_type
class GreedySamplingStrategy(BaseModel):
@ -201,58 +191,6 @@ class ToolResponseMessage(BaseModel):
content: InterleavedContent
@json_schema_type
class CompletionMessage(BaseModel):
"""A message containing the model's (assistant) response in a chat conversation.
:param role: Must be "assistant" to identify this as the model's response
:param content: The content of the model's response
:param stop_reason: Reason why the model stopped generating. Options are:
- `StopReason.end_of_turn`: The model finished generating the entire response.
- `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response.
- `StopReason.out_of_tokens`: The model ran out of token budget.
:param tool_calls: List of tool calls. Each tool call is a ToolCall object.
"""
role: Literal["assistant"] = "assistant"
content: InterleavedContent
stop_reason: StopReason
tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
Message = Annotated[
UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
Field(discriminator="role"),
]
register_schema(Message, name="Message")
@json_schema_type
class ToolResponse(BaseModel):
"""Response from a tool invocation.
:param call_id: Unique identifier for the tool call this response is for
:param tool_name: Name of the tool that was invoked
:param content: The response content from the tool
:param metadata: (Optional) Additional metadata about the tool response
"""
call_id: str
tool_name: BuiltinTool | str
content: InterleavedContent
metadata: dict[str, Any] | None = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
class ToolChoice(Enum):
"""Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
@ -289,22 +227,6 @@ class ChatCompletionResponseEventType(Enum):
progress = "progress"
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""An event during chat completion generation.
:param event_type: Type of the event
:param delta: Content generated since last event. This can be one or more tokens, or a tool call.
:param logprobs: Optional log probabilities for generated tokens
:param stop_reason: Optional reason why generation stopped, if complete
"""
event_type: ChatCompletionResponseEventType
delta: ContentDelta
logprobs: list[TokenLogProbs] | None = None
stop_reason: StopReason | None = None
class ResponseFormatType(StrEnum):
"""Types of formats for structured (guided) decoding.
@ -357,34 +279,6 @@ class CompletionRequest(BaseModel):
logprobs: LogProbConfig | None = None
@json_schema_type
class CompletionResponse(MetricResponseMixin):
"""Response from a completion request.
:param content: The generated completion text
:param stop_reason: Reason why generation stopped
:param logprobs: Optional log probabilities for generated tokens
"""
content: str
stop_reason: StopReason
logprobs: list[TokenLogProbs] | None = None
@json_schema_type
class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response.
:param delta: New content generated since last chunk. This can be one or more tokens.
:param stop_reason: Optional reason why generation stopped, if complete
:param logprobs: Optional log probabilities for generated tokens
"""
delta: str
stop_reason: StopReason | None = None
logprobs: list[TokenLogProbs] | None = None
class SystemMessageBehavior(Enum):
"""Config for how to override the default system prompt.
@ -398,70 +292,6 @@ class SystemMessageBehavior(Enum):
replace = "replace"
@json_schema_type
class ToolConfig(BaseModel):
"""Configuration for tool use.
:param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
:param system_message_behavior: (Optional) Config for how to override the default system prompt.
- `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
- `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
'{{function_definitions}}' to indicate where the function definitions should be inserted.
"""
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
def model_post_init(self, __context: Any) -> None:
if isinstance(self.tool_choice, str):
try:
self.tool_choice = ToolChoice[self.tool_choice]
except KeyError:
pass
# This is an internally used class
@json_schema_type
class ChatCompletionRequest(BaseModel):
model: str
messages: list[Message]
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
response_format: ResponseFormat | None = None
stream: bool | None = False
logprobs: LogProbConfig | None = None
@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed chat completion response.
:param event: The event containing the new content
"""
event: ChatCompletionResponseEvent
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request.
:param completion_message: The complete response message
:param logprobs: Optional log probabilities for generated tokens
"""
completion_message: CompletionMessage
logprobs: list[TokenLogProbs] | None = None
@json_schema_type
class EmbeddingsResponse(BaseModel):
"""Response containing generated embeddings.

View file

@ -76,7 +76,7 @@ class Inspect(Protocol):
List all available API routes with their methods and implementing providers.
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns all non-deprecated routes.
:returns: Response containing information about all available routes.
"""
...
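
A hedged sketch of exercising the filter over HTTP, assuming a stack listening on localhost:8321 and that the route is exposed at `/v1/inspect/routes`; the response field names are assumptions, not taken from this hunk:

```python
import httpx

def list_routes(api_filter: str | None = None) -> None:
    # No filter: all non-deprecated routes (the new default described above).
    params = {"api_filter": api_filter} if api_filter else {}
    resp = httpx.get("http://localhost:8321/v1/inspect/routes", params=params)
    resp.raise_for_status()
    for route in resp.json().get("data", []):
        print(route.get("method"), route.get("route"))

list_routes()              # all non-deprecated routes across levels
list_routes("v1beta")      # only non-deprecated v1beta routes
list_routes("deprecated")  # deprecated routes across all levels
```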

View file

@ -136,7 +136,7 @@ class Models(Protocol):
"""
...
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_model(
self,
model_id: str,
@ -158,7 +158,7 @@ class Models(Protocol):
"""
...
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
async def unregister_model(
self,
model_id: str,

View file

@ -178,7 +178,7 @@ class ScoringFunctions(Protocol):
"""
...
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_scoring_function(
self,
scoring_fn_id: str,
@ -199,7 +199,9 @@ class ScoringFunctions(Protocol):
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(
route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
"""Unregister a scoring function.

View file

@ -67,7 +67,7 @@ class Shields(Protocol):
"""
...
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_shield(
self,
shield_id: str,
@ -85,7 +85,7 @@ class Shields(Protocol):
"""
...
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
async def unregister_shield(self, identifier: str) -> None:
"""Unregister a shield.

View file

@ -109,7 +109,7 @@ class ListToolDefsResponse(BaseModel):
@runtime_checkable
@telemetry_traceable
class ToolGroups(Protocol):
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_tool_group(
self,
toolgroup_id: str,
@ -167,7 +167,7 @@ class ToolGroups(Protocol):
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
async def unregister_toolgroup(
self,
toolgroup_id: str,

View file

@ -10,7 +10,7 @@
# the root directory of this source tree.
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from fastapi import Body
from fastapi import Body, Query
from pydantic import BaseModel, Field
from llama_stack.apis.common.tracing import telemetry_traceable
@ -224,10 +224,16 @@ class VectorStoreContent(BaseModel):
:param type: Content type, currently only "text" is supported
:param text: The actual text content
:param embedding: Optional embedding vector for this content chunk
:param chunk_metadata: Optional chunk metadata
:param metadata: Optional user-defined metadata
"""
type: Literal["text"]
text: str
embedding: list[float] | None = None
chunk_metadata: ChunkMetadata | None = None
metadata: dict[str, Any] | None = None
@json_schema_type
@ -280,6 +286,22 @@ class VectorStoreDeleteResponse(BaseModel):
deleted: bool = True
@json_schema_type
class VectorStoreFileContentResponse(BaseModel):
"""Represents the parsed content of a vector store file.
:param object: The object type, which is always `vector_store.file_content.page`
:param data: Parsed content of the file
:param has_more: Indicates if there are more content pages to fetch
:param next_page: The token for the next page, if any
"""
object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page"
data: list[VectorStoreContent]
has_more: bool = False
next_page: str | None = None
@json_schema_type
class VectorStoreChunkingStrategyAuto(BaseModel):
"""Automatic chunking strategy for vector store files.
@ -395,22 +417,6 @@ class VectorStoreListFilesResponse(BaseModel):
has_more: bool = False
@json_schema_type
class VectorStoreFileContentsResponse(BaseModel):
"""Response from retrieving the contents of a vector store file.
:param file_id: Unique identifier for the file
:param filename: Name of the file
:param attributes: Key-value attributes associated with the file
:param content: List of content items from the file
"""
file_id: str
filename: str
attributes: dict[str, Any]
content: list[VectorStoreContent]
@json_schema_type
class VectorStoreFileDeleteResponse(BaseModel):
"""Response from deleting a vector store file.
@ -732,12 +738,16 @@ class VectorIO(Protocol):
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
include_embeddings: Annotated[bool | None, Query(default=False)] = False,
include_metadata: Annotated[bool | None, Query(default=False)] = False,
) -> VectorStoreFileContentResponse:
"""Retrieves the contents of a vector store file.
:param vector_store_id: The ID of the vector store containing the file to retrieve.
:param file_id: The ID of the file to retrieve.
:returns: A list of InterleavedContent representing the file contents.
:param include_embeddings: Whether to include embedding vectors in the response.
:param include_metadata: Whether to include chunk metadata in the response.
:returns: File contents, optionally with embeddings and metadata based on query parameters.
"""
...
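
A minimal sketch of the updated call, assuming an initialized `VectorIO` implementation and existing vector store / file ids (identifiers are illustrative); the response fields follow the `VectorStoreFileContentResponse` and extended `VectorStoreContent` models added above:

```python
async def show_file_contents(vector_io, vector_store_id: str, file_id: str) -> None:
    page = await vector_io.openai_retrieve_vector_store_file_contents(
        vector_store_id=vector_store_id,
        file_id=file_id,
        include_embeddings=True,   # request per-chunk embedding vectors
        include_metadata=True,     # request chunk and user-defined metadata
    )
    for item in page.data:         # list[VectorStoreContent]
        has_vec = item.embedding is not None
        print(item.text[:60], "| embedding:", has_vec, "| metadata:", item.metadata)
    print("has_more:", page.has_more, "next_page:", page.next_page)
```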

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib.resources
import sys
from pydantic import BaseModel
@ -12,9 +11,6 @@ from termcolor import cprint
from llama_stack.core.datatypes import BuildConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
@ -101,64 +97,3 @@ def print_pip_install_help(config: BuildConfig):
for special_dep in special_deps:
cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
print()
def build_image(
build_config: BuildConfig,
image_name: str,
distro_or_config: str,
run_config: str | None = None,
):
container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
if build_config.external_apis_dir:
external_apis = load_external_apis(build_config)
if external_apis:
for _, api_spec in external_apis.items():
normal_deps.extend(api_spec.pip_packages)
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "core/build_container.sh")
args = [
script,
"--distro-or-config",
distro_or_config,
"--image-name",
image_name,
"--container-base",
container_base,
"--normal-deps",
" ".join(normal_deps),
]
# When building from a config file (not a template), include the run config path in the
# build arguments
if run_config is not None:
args.extend(["--run-config", run_config])
else:
script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh")
args = [
script,
"--env-name",
str(image_name),
"--normal-deps",
" ".join(normal_deps),
]
# Always pass both arguments, even if empty, to maintain consistent positional arguments
if special_deps:
args.extend(["--optional-deps", "#".join(special_deps)])
if external_provider_deps:
args.extend(
["--external-provider-deps", "#".join(external_provider_deps)]
) # the script will install external provider module, get its deps, and install those too.
return_code = run_command(args)
if return_code != 0:
log.error(
f"Failed to build target {image_name} with return code {return_code}",
)
return return_code

View file

@ -15,7 +15,6 @@ from llama_stack.apis.inspect import (
RouteInfo,
VersionInfo,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
@ -46,8 +45,8 @@ class DistributionInspectImpl(Inspect):
# Helper function to determine if a route should be included based on api_filter
def should_include_route(webmethod) -> bool:
if api_filter is None:
# Default: only non-deprecated v1 APIs
return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
# Default: only non-deprecated APIs
return not webmethod.deprecated
elif api_filter == "deprecated":
# Special filter: show deprecated routes regardless of their actual level
return bool(webmethod.deprecated)

View file

@ -389,6 +389,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
# Pass through params that aren't already handled as path params
if options.params:
extra_query_params = {k: v for k, v in options.params.items() if k not in path_params}
if extra_query_params:
body["extra_query"] = extra_query_params
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
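
A hedged sketch of what this pass-through enables from the caller's side: query parameters supplied via the client's `extra_query` option (the OpenAI-style client convention) now reach endpoints such as the vector store file-contents route; the exact resource method below is an assumption:

```python
# Assumes `client` is an initialized AsyncLlamaStackAsLibraryClient and that the
# generated resource method mirrors the OpenAI client; ids are illustrative.
contents = await client.vector_stores.files.content(
    file_id="file_abc",
    vector_store_id="vs_123",
    extra_query={"include_embeddings": True, "include_metadata": True},
)
```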

View file

@ -6,7 +6,7 @@
from typing import Any
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
@ -52,7 +52,7 @@ class SafetyRouter(Safety):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")

View file

@ -24,7 +24,7 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategyStaticConfig,
VectorStoreDeleteResponse,
VectorStoreFileBatchObject,
VectorStoreFileContentsResponse,
VectorStoreFileContentResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFilesListInBatchResponse,
@ -247,6 +247,13 @@ class VectorIORouter(VectorIO):
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
# Check if provider_id is being changed (not supported)
if metadata and "provider_id" in metadata:
current_store = await self.routing_table.get_object_by_identifier("vector_store", vector_store_id)
if current_store and current_store.provider_id != metadata["provider_id"]:
raise ValueError("provider_id cannot be changed after vector store creation")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
@ -338,12 +345,19 @@ class VectorIORouter(VectorIO):
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
include_embeddings: bool | None = False,
include_metadata: bool | None = False,
) -> VectorStoreFileContentResponse:
logger.debug(
f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}, "
f"include_embeddings={include_embeddings}, include_metadata={include_metadata}"
)
return await self.routing_table.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
include_embeddings=include_embeddings,
include_metadata=include_metadata,
)
async def openai_update_vector_store_file(

View file

@ -15,7 +15,7 @@ from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileContentResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
@ -195,12 +195,17 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
include_embeddings: bool | None = False,
include_metadata: bool | None = False,
) -> VectorStoreFileContentResponse:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
include_embeddings=include_embeddings,
include_metadata=include_metadata,
)
async def openai_update_vector_store_file(

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .oci import get_distribution_template # noqa: F401

View file

@ -0,0 +1,35 @@
version: 2
distribution_spec:
description: Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM
inference with scalable cloud services
providers:
inference:
- provider_type: remote::oci
vector_io:
- provider_type: inline::faiss
- provider_type: remote::chromadb
- provider_type: remote::pgvector
safety:
- provider_type: inline::llama-guard
agents:
- provider_type: inline::meta-reference
eval:
- provider_type: inline::meta-reference
datasetio:
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
files:
- provider_type: inline::localfs
image_type: venv
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]

View file

@ -0,0 +1,140 @@
---
orphan: true
---
# OCI Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}
## Prerequisites
### Oracle Cloud Infrastructure Setup
Before using the OCI Generative AI distribution, ensure you have:
1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/)
2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy
3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models
4. **Authentication**: Configure authentication using either:
- **Instance Principal** (recommended for cloud-hosted deployments)
- **API Key** (for on-premises or development environments)
### Authentication Methods
#### Instance Principal Authentication (Recommended)
Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments.
Requirements:
- Instance must be running in an Oracle Cloud Infrastructure compartment
- Instance must have appropriate IAM policies to access Generative AI services
#### API Key Authentication
For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file.
### Required IAM Policies
Ensure your OCI user or instance has the following policy statements:
```
Allow group <group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
Allow group <group_name> to manage generative-ai-inference-endpoints in compartment <compartment_name>
```
## Supported Services
### Inference: OCI Generative AI
Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports:
- **Chat Completions**: Conversational AI with context awareness
- **Text Generation**: Complete prompts and generate text content
#### Available Models
OCI Generative AI provides access to models from Meta, Cohere, OpenAI, Grok, and other providers.
### Safety: Llama Guard
For content safety and moderation, this distribution uses Meta's Llama Guard model through the OCI Generative AI service to provide:
- Content filtering and moderation
- Policy compliance checking
- Harmful content detection
### Vector Storage: Multiple Options
The distribution supports several vector storage providers:
- **FAISS**: Local in-memory vector search
- **ChromaDB**: Distributed vector database
- **PGVector**: PostgreSQL with vector extensions
### Additional Services
- **Dataset I/O**: Local filesystem and Hugging Face integration
- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities
- **Evaluation**: Meta reference evaluation framework
## Running Llama Stack with OCI
You can run the OCI distribution via Docker or a local virtual environment.
### Via venv
If you've set up your local development environment, you can run the distribution directly from your local virtual environment.
```bash
OCI_AUTH=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci
```
### Configuration Examples
#### Using Instance Principal (Recommended for Production)
```bash
export OCI_AUTH_TYPE=instance_principal
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..<your-compartment-id>
```
#### Using API Key Authentication (Development)
```bash
export OCI_AUTH_TYPE=config_file
export OCI_CONFIG_FILE_PATH=~/.oci/config
export OCI_CLI_PROFILE=DEFAULT
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
```
## Regional Endpoints
OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit:
https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions
## Troubleshooting
### Common Issues
1. **Authentication Errors**: Verify your OCI credentials and IAM policies
2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region
3. **Permission Denied**: Check compartment permissions and Generative AI service access
4. **Region Unavailable**: Verify the specified region supports Generative AI services
### Getting Help
For additional support:
- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)
- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues)

View file

@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.oci.config import OCIConfig
def get_distribution_template(name: str = "oci") -> DistributionTemplate:
providers = {
"inference": [BuildProvider(provider_type="remote::oci")],
"vector_io": [
BuildProvider(provider_type="inline::faiss"),
BuildProvider(provider_type="remote::chromadb"),
BuildProvider(provider_type="remote::pgvector"),
],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [
BuildProvider(provider_type="remote::huggingface"),
BuildProvider(provider_type="inline::localfs"),
],
"scoring": [
BuildProvider(provider_type="inline::basic"),
BuildProvider(provider_type="inline::llm-as-judge"),
BuildProvider(provider_type="inline::braintrust"),
],
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
}
inference_provider = Provider(
provider_id="oci",
provider_type="remote::oci",
config=OCIConfig.sample_run_config(),
)
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
]
return DistributionTemplate(
name=name,
distro_type="remote_hosted",
description="Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM inference with scalable cloud services",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider],
"files": [files_provider],
},
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"OCI_AUTH_TYPE": (
"instance_principal",
"OCI authentication type (instance_principal or config_file)",
),
"OCI_REGION": (
"",
"OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1)",
),
"OCI_COMPARTMENT_OCID": (
"",
"OCI compartment ID for the Generative AI service",
),
"OCI_CONFIG_FILE_PATH": (
"~/.oci/config",
"OCI config file path (required if OCI_AUTH_TYPE is config_file)",
),
"OCI_CLI_PROFILE": (
"DEFAULT",
"OCI CLI profile name to use from config file",
),
},
)

View file

@@ -0,0 +1,136 @@
version: 2
image_name: oci
apis:
- agents
- datasetio
- eval
- files
- inference
- safety
- scoring
- tool_runtime
- vector_io
providers:
inference:
- provider_id: oci
provider_type: remote::oci
config:
oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
oci_region: ${env.OCI_REGION:=us-ashburn-1}
oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
persistence:
namespace: vector_io::faiss
backend: kv_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/oci/files}
metadata_store:
table_name: files_metadata
backend: sql_default
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models: []
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
server:
port: 8321
telemetry:
enabled: true
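Not part of the generated config: a minimal smoke test against a stack launched from this run.yaml. It assumes the server listens on the port configured above (8321) and that a chat-capable OCI model has been registered; the model id is a hypothetical placeholder.

```python
# Hedged smoke test: the stack exposes OpenAI-compatible /v1 routes, so the stock
# OpenAI client can talk to it. The base URL and model id are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="not-needed")
resp = client.chat.completions.create(
    model="meta.llama-3.3-70b-instruct",  # hypothetical OCI model display name
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=32,
)
print(resp.choices[0].message.content)
```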

View file

@@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import (
)
from termcolor import cprint
from llama_stack.models.llama.datatypes import ToolPromptFormat
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer

View file

@@ -15,13 +15,10 @@ from pathlib import Path
from termcolor import colored
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat
from ..datatypes import (
BuiltinTool,
RawMessage,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from . import template_data
from .chat_format import ChatFormat

View file

@@ -15,7 +15,7 @@ import textwrap
from datetime import datetime
from typing import Any
from llama_stack.apis.inference import (
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolDefinition,
)

View file

@@ -8,8 +8,9 @@ import json
import re
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
from ..datatypes import RecursiveType
logger = get_logger(name=__name__, category="models::llama")

View file

@@ -13,7 +13,7 @@
import textwrap
from llama_stack.apis.inference import ToolDefinition
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import (
PromptTemplate,
PromptTemplateGeneratorBase,

View file

@@ -102,6 +102,7 @@ class MetaReferenceAgentsImpl(Agents):
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None,
max_tool_calls: int | None = None,
) -> OpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
result = await self.openai_responses_impl.create_openai_response(
@@ -119,6 +120,7 @@ class MetaReferenceAgentsImpl(Agents):
include,
max_infer_iters,
guardrails,
max_tool_calls,
)
return result # type: ignore[no-any-return]

View file

@@ -255,6 +255,7 @@ class OpenAIResponsesImpl:
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[str | ResponseGuardrailSpec] | None = None,
max_tool_calls: int | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
@@ -270,6 +271,9 @@ class OpenAIResponsesImpl:
if not conversation.startswith("conv_"):
raise InvalidConversationIdError(conversation)
if max_tool_calls is not None and max_tool_calls < 1:
raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")
stream_gen = self._create_streaming_response(
input=input,
conversation=conversation,
@@ -282,6 +286,7 @@ class OpenAIResponsesImpl:
tools=tools,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
max_tool_calls=max_tool_calls,
)
if stream:
@@ -331,6 +336,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
max_tool_calls: int | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# These should never be None when called from create_openai_response (which sets defaults)
# but we assert here to help mypy understand the types
@@ -373,6 +379,7 @@ class OpenAIResponsesImpl:
safety_api=self.safety_api,
guardrail_ids=guardrail_ids,
instructions=instructions,
max_tool_calls=max_tool_calls,
)
# Stream the response

View file

@@ -115,6 +115,7 @@ class StreamingResponseOrchestrator:
safety_api,
guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None,
max_tool_calls: int | None = None,
):
self.inference_api = inference_api
self.ctx = ctx
@@ -126,6 +127,10 @@ class StreamingResponseOrchestrator:
self.safety_api = safety_api
self.guardrail_ids = guardrail_ids or []
self.prompt = prompt
# System message that is inserted into the model's context
self.instructions = instructions
# Max number of total calls to built-in tools that can be processed in a response
self.max_tool_calls = max_tool_calls
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@@ -139,8 +144,8 @@ class StreamingResponseOrchestrator:
self.accumulated_usage: OpenAIResponseUsage | None = None
# Track if we've sent a refusal response
self.violation_detected = False
# system message that is inserted into the model's context
self.instructions = instructions
# Track total calls made to built-in tools
self.accumulated_builtin_tool_calls = 0
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
@@ -186,6 +191,7 @@ class StreamingResponseOrchestrator:
usage=self.accumulated_usage,
instructions=self.instructions,
prompt=self.prompt,
max_tool_calls=self.max_tool_calls,
)
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@@ -894,6 +900,11 @@ class StreamingResponseOrchestrator:
"""Coordinate execution of both function and non-function tool calls."""
# Execute non-function tool calls
for tool_call in non_function_tool_calls:
# Check if total calls made to built-in and mcp tools exceed max_tool_calls
if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
break
# Find the item_id for this tool call
matching_item_id = None
for index, item_id in completion_result_data.tool_call_item_ids.items():
@@ -974,6 +985,9 @@ class StreamingResponseOrchestrator:
if tool_response_message:
next_turn_messages.append(tool_response_message)
# Track number of calls made to built-in and mcp tools
self.accumulated_builtin_tool_calls += 1
# Execute function tool calls (client-side)
for tool_call in function_tool_calls:
# Find the item_id for this tool call from our tracking dictionary
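A hedged usage sketch for the new `max_tool_calls` parameter: it caps how many built-in/MCP tool executions a single response may perform, and values below 1 are rejected by the validation added above. The endpoint, model id, and tool spec below are illustrative assumptions, not part of this change.

```python
# Hedged sketch: cap built-in tool executions for one response via the Responses API.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="not-needed")
resp = client.responses.create(
    model="meta.llama-3.3-70b-instruct",  # hypothetical placeholder
    input="Find two recent articles about OCI Generative AI.",
    tools=[{"type": "web_search"}],       # assumes the builtin::websearch tool group is configured
    max_tool_calls=1,                     # further built-in/MCP calls are ignored once the cap is hit
)
print(resp.output_text)
```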

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import math
from collections.abc import Generator
from typing import Optional
import torch
@@ -14,21 +13,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
from llama_stack.apis.inference import (
GreedySamplingStrategy,
JsonSchemaResponseFormat,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIResponseFormatJSONSchema,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TopPSamplingStrategy,
)
from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model, ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
get_default_tool_prompt_format,
)
from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig
@@ -106,14 +103,6 @@ def _infer_sampling_params(sampling_params: SamplingParams):
return temperature, top_p
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
tool_config = request.tool_config
if tool_config is not None and tool_config.tool_prompt_format is not None:
return tool_config.tool_prompt_format
else:
return get_default_tool_prompt_format(request.model)
class LlamaGenerator:
def __init__(
self,
@@ -157,55 +146,56 @@ class LlamaGenerator:
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def completion(
self,
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
first_request.response_format,
),
)
def chat_completion(
self,
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
request: OpenAIChatCompletionRequestWithExtraBody,
raw_messages: list,
):
"""Generate chat completion using OpenAI request format.
Args:
request: OpenAI chat completion request
raw_messages: Pre-converted list of RawMessage objects
"""
# Determine tool prompt format
tool_prompt_format = ToolPromptFormat.json
# Prepare sampling params
sampling_params = SamplingParams()
if request.temperature is not None or request.top_p is not None:
sampling_params.strategy = TopPSamplingStrategy(
temperature=request.temperature if request.temperature is not None else 1.0,
top_p=request.top_p if request.top_p is not None else 1.0,
)
if request.max_tokens:
sampling_params.max_tokens = request.max_tokens
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
# Get logits processor for response format
logits_processor = None
if request.response_format:
if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
# Extract the actual schema from OpenAIJSONSchema TypedDict
schema_dict = request.response_format.json_schema.get("schema") or {}
json_schema_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema,
json_schema=schema_dict,
)
logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
# Generate
yield from self.inner_generator.generate(
llm_inputs=[
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
for request in request_batch
],
llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(first_request.logprobs),
logprobs=False,
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
first_request.response_format,
),
logits_processor=logits_processor,
)

View file

@@ -5,12 +5,19 @@
# the root directory of this source tree.
import asyncio
import time
import uuid
from collections.abc import AsyncIterator
from llama_stack.apis.inference import (
InferenceProvider,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionUsage,
OpenAIChoice,
OpenAICompletionRequestWithExtraBody,
OpenAIUserMessageParam,
ToolChoice,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
@@ -19,12 +26,20 @@ from llama_stack.apis.inference.inference import (
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
JsonCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
@@ -44,6 +59,170 @@ log = get_logger(__name__, category="inference")
SEMAPHORE = asyncio.Semaphore(1)
def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
"""Convert OpenAI tool format to ToolDefinition format."""
# OpenAI tools have function.name and function.parameters
return ToolDefinition(
tool_name=tool.function.name,
description=tool.function.description or "",
parameters=tool.function.parameters or {},
)
def _get_tool_choice_prompt(tool_choice, tools) -> str:
"""Generate prompt text for tool_choice behavior."""
if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto":
return ""
elif tool_choice == ToolChoice.required or tool_choice == "required":
return "You MUST use one of the provided functions/tools to answer the user query."
elif tool_choice == ToolChoice.none or tool_choice == "none":
return ""
else:
# Specific tool specified
return f"You MUST use the tool `{tool_choice}` to answer the user query."
def _raw_content_as_str(content) -> str:
"""Convert RawContent to string for system messages."""
if isinstance(content, str):
return content
elif isinstance(content, RawTextItem):
return content.text
elif isinstance(content, list):
return "\n".join(_raw_content_as_str(c) for c in content)
else:
return "<media>"
def _augment_raw_messages_for_tools_llama_3_1(
raw_messages: list[RawMessage],
tools: list,
tool_choice,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions for Llama 3.1 style models."""
messages = raw_messages.copy()
existing_system_message = None
if messages and messages[0].role == "system":
existing_system_message = messages.pop(0)
sys_content = ""
# Add tool definitions first (if present)
if tools:
# Convert OpenAI tools to ToolDefinitions
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
# For OpenAI format, all tools are custom (have string names)
tool_gen = JsonCustomToolGenerator()
tool_template = tool_gen.gen(tool_definitions)
sys_content += tool_template.render()
sys_content += "\n"
# Add default system prompt
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content += default_template.render()
# Add existing system message if present
if existing_system_message:
sys_content += "\n" + _raw_content_as_str(existing_system_message.content)
# Add tool choice prompt if needed
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
sys_content += "\n" + tool_choice_prompt
# Create new system message
new_system_message = RawMessage(
role="system",
content=[RawTextItem(text=sys_content.strip())],
)
return [new_system_message] + messages
def _augment_raw_messages_for_tools_llama_4(
raw_messages: list[RawMessage],
tools: list,
tool_choice,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models."""
messages = raw_messages.copy()
existing_system_message = None
if messages and messages[0].role == "system":
existing_system_message = messages.pop(0)
sys_content = ""
# Add tool definitions if present
if tools:
# Convert OpenAI tools to ToolDefinitions
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
# Use python_list format for Llama 4
tool_gen = PythonListCustomToolGeneratorLlama4()
system_prompt = None
if existing_system_message:
system_prompt = _raw_content_as_str(existing_system_message.content)
tool_template = tool_gen.gen(tool_definitions, system_prompt)
sys_content = tool_template.render()
elif existing_system_message:
# No tools, just use existing system message
sys_content = _raw_content_as_str(existing_system_message.content)
# Add tool choice prompt if needed
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
sys_content += "\n" + tool_choice_prompt
if sys_content:
new_system_message = RawMessage(
role="system",
content=[RawTextItem(text=sys_content.strip())],
)
return [new_system_message] + messages
return messages
def augment_raw_messages_for_tools(
raw_messages: list[RawMessage],
params: OpenAIChatCompletionRequestWithExtraBody,
llama_model,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions based on model family."""
if not params.tools:
return raw_messages
# Determine augmentation strategy based on model family
if llama_model.model_family == ModelFamily.llama3_1 or (
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
):
# Llama 3.1 and Llama 3.2 multimodal use JSON format
return _augment_raw_messages_for_tools_llama_3_1(
raw_messages,
params.tools,
params.tool_choice,
)
elif llama_model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# Llama 3.2/3.3/4 use python_list format
return _augment_raw_messages_for_tools_llama_4(
raw_messages,
params.tools,
params.tool_choice,
)
else:
# Default to Llama 3.1 style
return _augment_raw_messages_for_tools_llama_3_1(
raw_messages,
params.tools,
params.tool_choice,
)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
@@ -136,11 +315,14 @@ class MetaReferenceInferenceImpl(
self.llama_model = llama_model
log.info("Warming up...")
await self.openai_chat_completion(
params=OpenAIChatCompletionRequestWithExtraBody(
model=model_id,
messages=[{"role": "user", "content": "Hi how are you?"}],
messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
max_tokens=20,
)
)
log.info("Warmed up!")
def check_model(self, request) -> None:
@@ -155,4 +337,207 @@
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
self.check_model(params)
# Convert OpenAI messages to RawMessages
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_openai_message_to_raw_message,
decode_assistant_message,
)
raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
# Augment messages with tool definitions if tools are present
raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)
# Call generator's chat_completion method (works for both single-GPU and model-parallel)
if isinstance(self.generator, LlamaGenerator):
generator = self.generator.chat_completion(params, raw_messages)
else:
# Model parallel: submit task to process group
generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
# Check if streaming is requested
if params.stream:
return self._stream_chat_completion(generator, params)
# Non-streaming: collect all generated text
generated_text = ""
for result_batch in generator:
for result in result_batch:
if not result.ignore_token and result.source == "output":
generated_text += result.text
# Decode assistant message to extract tool calls and determine stop_reason
# Default to end_of_turn if generation completed normally
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
# Convert tool calls to OpenAI format
openai_tool_calls = None
if decoded_message.tool_calls:
from llama_stack.apis.inference import (
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
)
openai_tool_calls = [
OpenAIChatCompletionToolCall(
# generate a uuid for the call id. This is the only inline provider that does this, so we need to get creative.
id=f"call_{uuid.uuid4().hex[:24]}",
type="function",
function=OpenAIChatCompletionToolCallFunction(
name=str(tc.tool_name),
arguments=tc.arguments,
),
)
for tc in decoded_message.tool_calls
]
# Determine finish_reason based on whether tool calls are present
finish_reason = "tool_calls" if openai_tool_calls else "stop"
# Extract content from decoded message
content = ""
if isinstance(decoded_message.content, str):
content = decoded_message.content
elif isinstance(decoded_message.content, list):
for item in decoded_message.content:
if isinstance(item, RawTextItem):
content += item.text
# Create OpenAI response
# generate a uuid for the response id. This is the only inline provider that does this, so we need to get creative.
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
return OpenAIChatCompletion(
id=response_id,
object="chat.completion",
created=created,
model=params.model,
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant",
content=content,
tool_calls=openai_tool_calls,
),
finish_reason=finish_reason,
logprobs=None,
)
],
usage=OpenAIChatCompletionUsage(
prompt_tokens=0, # TODO: calculate properly
completion_tokens=0, # TODO: calculate properly
total_tokens=0, # TODO: calculate properly
),
)
async def _stream_chat_completion(
self,
generator,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""Stream chat completion chunks as they're generated."""
from llama_stack.apis.inference import (
OpenAIChatCompletionChunk,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoiceDelta,
OpenAIChunkChoice,
)
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
generated_text = ""
# Yield chunks as tokens are generated
for result_batch in generator:
for result in result_batch:
if result.ignore_token or result.source != "output":
continue
generated_text += result.text
# Yield delta chunk with the new text
chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(
role="assistant",
content=result.text,
),
finish_reason="",
logprobs=None,
)
],
)
yield chunk
# After generation completes, decode the full message to extract tool calls
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
# If tool calls are present, yield a final chunk with tool_calls
if decoded_message.tool_calls:
openai_tool_calls = [
OpenAIChatCompletionToolCall(
# generate a uuid for the call id. This is the only inline provider that does this, so we need to get creative.
id=f"call_{uuid.uuid4().hex[:24]}",
type="function",
function=OpenAIChatCompletionToolCallFunction(
name=str(tc.tool_name),
arguments=tc.arguments,
),
)
for tc in decoded_message.tool_calls
]
# Yield chunk with tool_calls
chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(
role="assistant",
tool_calls=openai_tool_calls,
),
finish_reason="",
logprobs=None,
)
],
)
yield chunk
finish_reason = "tool_calls"
else:
finish_reason = "stop"
# Yield final chunk with finish_reason
final_chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(),
finish_reason=finish_reason,
logprobs=None,
)
],
)
yield final_chunk

View file

@@ -4,17 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable, Generator
from copy import deepcopy
from collections.abc import Callable
from functools import partial
from typing import Any
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .parallel_utils import ModelParallelProcessGroup
@@ -23,12 +18,14 @@ class ModelRunner:
def __init__(self, llama):
self.llama = llama
# the `task` object is the same one that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: Any):
if task[0] == "chat_completion":
return self.llama.chat_completion(task[1])
task_type = task[0]
if task_type == "chat_completion":
# task[1] is [params, raw_messages]
params, raw_messages = task[1]
return self.llama.chat_completion(params, raw_messages)
else:
raise ValueError(f"Unexpected task type {task[0]}")
raise ValueError(f"Unexpected task type {task_type}")
def init_model_cb(
@@ -78,19 +75,3 @@ class LlamaModelParallelGenerator:
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def completion(
self,
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("completion", req_obj))
yield from gen
def chat_completion(
self,
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("chat_completion", req_obj))
yield from gen

View file

@@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
log = get_logger(name=__name__, category="inference")
@@ -69,10 +65,7 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: tuple[
str,
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
]
task: tuple[str, list]
class TaskResponse(BaseModel):
@@ -328,10 +321,7 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: tuple[
str,
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
],
req: tuple[str, list],
) -> Generator:
assert not self.running, "inference already running"

View file

@@ -22,9 +22,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
)
from .config import SentenceTransformersInferenceConfig
@@ -32,7 +29,6 @@ log = get_logger(name=__name__, category="inference")
class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,

View file

@@ -297,6 +297,20 @@ Available Models:
Azure OpenAI inference provider for accessing GPT models and other Azure services.
Provider documentation
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
""",
),
RemoteProviderSpec(
api=Api.inference,
provider_type="remote::oci",
adapter_type="oci",
pip_packages=["oci"],
module="llama_stack.providers.remote.inference.oci",
config_class="llama_stack.providers.remote.inference.oci.config.OCIConfig",
provider_data_validator="llama_stack.providers.remote.inference.oci.config.OCIProviderDataValidator",
description="""
Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models.
Provider documentation
https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
""",
),
]

View file

@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import OCIConfig
async def get_adapter_impl(config: OCIConfig, _deps) -> InferenceProvider:
from .oci import OCIInferenceAdapter
adapter = OCIInferenceAdapter(config=config)
await adapter.initialize()
return adapter

View file

@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Generator, Mapping
from typing import Any, override
import httpx
import oci
import requests
from oci.config import DEFAULT_LOCATION, DEFAULT_PROFILE
OciAuthSigner = type[oci.signer.AbstractBaseSigner]
class HttpxOciAuth(httpx.Auth):
"""
Custom HTTPX authentication class that implements OCI request signing.
This class handles the authentication flow for HTTPX requests by signing them
using the OCI Signer, which adds the necessary authentication headers for
OCI API calls.
Attributes:
signer (oci.signer.Signer): The OCI signer instance used for request signing
"""
def __init__(self, signer: OciAuthSigner):
self.signer = signer
@override
def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
# Read the request content to handle streaming requests properly
try:
content = request.content
except httpx.RequestNotRead:
# For streaming requests, we need to read the content first
content = request.read()
req = requests.Request(
method=request.method,
url=str(request.url),
headers=dict(request.headers),
data=content,
)
prepared_request = req.prepare()
# Sign the request using the OCI Signer
self.signer.do_request_sign(prepared_request) # type: ignore
# Update the original HTTPX request with the signed headers
request.headers.update(prepared_request.headers)
yield request
class OciInstancePrincipalAuth(HttpxOciAuth):
def __init__(self, **kwargs: Mapping[str, Any]):
self.signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner(**kwargs)
class OciUserPrincipalAuth(HttpxOciAuth):
def __init__(self, config_file: str = DEFAULT_LOCATION, profile_name: str = DEFAULT_PROFILE):
config = oci.config.from_file(config_file, profile_name)
oci.config.validate_config(config) # type: ignore
key_content = ""
with open(config["key_file"]) as f:
key_content = f.read()
self.signer = oci.signer.Signer(
tenancy=config["tenancy"],
user=config["user"],
fingerprint=config["fingerprint"],
private_key_file_location=config.get("key_file"),
pass_phrase="none", # type: ignore
private_key_content=key_content,
)
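A hedged usage sketch, not part of the module: because these classes implement `httpx.Auth`, any `httpx` request can be signed by iterating `auth_flow`. It assumes a valid `~/.oci/config` with an API signing key; the URL mirrors the endpoint shape the adapter builds and is illustrative only.

```python
# Hedged sketch: sign an outgoing httpx request with the config-file signer.
import httpx

auth = OciUserPrincipalAuth(config_file="~/.oci/config", profile_name="DEFAULT")
request = httpx.Request(
    "POST",
    "https://inference.generativeai.us-ashburn-1.oci.oraclecloud.com/20231130/actions/v1/chat/completions",
    json={},  # illustrative empty body; signing only needs the prepared request
)
for signed in auth.auth_flow(request):
    # The signer adds the OCI signature headers (authorization, date, content digest).
    print(sorted(signed.headers.keys()))
```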

View file

@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class OCIProviderDataValidator(BaseModel):
oci_auth_type: str = Field(
description="OCI authentication type (must be one of: instance_principal, config_file)",
)
oci_region: str = Field(
description="OCI region (e.g., us-ashburn-1)",
)
oci_compartment_id: str = Field(
description="OCI compartment ID for the Generative AI service",
)
oci_config_file_path: str | None = Field(
default="~/.oci/config",
description="OCI config file path (required if oci_auth_type is config_file)",
)
oci_config_profile: str | None = Field(
default="DEFAULT",
description="OCI config profile (required if oci_auth_type is config_file)",
)
@json_schema_type
class OCIConfig(RemoteInferenceProviderConfig):
oci_auth_type: str = Field(
description="OCI authentication type (must be one of: instance_principal, config_file)",
default_factory=lambda: os.getenv("OCI_AUTH_TYPE", "instance_principal"),
)
oci_region: str = Field(
default_factory=lambda: os.getenv("OCI_REGION", "us-ashburn-1"),
description="OCI region (e.g., us-ashburn-1)",
)
oci_compartment_id: str = Field(
default_factory=lambda: os.getenv("OCI_COMPARTMENT_OCID", ""),
description="OCI compartment ID for the Generative AI service",
)
oci_config_file_path: str = Field(
default_factory=lambda: os.getenv("OCI_CONFIG_FILE_PATH", "~/.oci/config"),
description="OCI config file path (required if oci_auth_type is config_file)",
)
oci_config_profile: str = Field(
default_factory=lambda: os.getenv("OCI_CLI_PROFILE", "DEFAULT"),
description="OCI config profile (required if oci_auth_type is config_file)",
)
@classmethod
def sample_run_config(
cls,
oci_auth_type: str = "${env.OCI_AUTH_TYPE:=instance_principal}",
oci_config_file_path: str = "${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}",
oci_config_profile: str = "${env.OCI_CLI_PROFILE:=DEFAULT}",
oci_region: str = "${env.OCI_REGION:=us-ashburn-1}",
oci_compartment_id: str = "${env.OCI_COMPARTMENT_OCID:=}",
**kwargs,
) -> dict[str, Any]:
return {
"oci_auth_type": oci_auth_type,
"oci_config_file_path": oci_config_file_path,
"oci_config_profile": oci_config_profile,
"oci_region": oci_region,
"oci_compartment_id": oci_compartment_id,
}
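A small sketch, not part of the module, showing how the `default_factory` fields above resolve from the environment; the values set here are hypothetical.

```python
# Hedged sketch: OCIConfig fields fall back to OCI_* environment variables.
import os

os.environ["OCI_AUTH_TYPE"] = "config_file"                            # hypothetical value
os.environ["OCI_REGION"] = "us-chicago-1"                              # hypothetical value
os.environ["OCI_COMPARTMENT_OCID"] = "ocid1.compartment.oc1..example"  # hypothetical placeholder

cfg = OCIConfig()
assert cfg.oci_auth_type == "config_file"
assert cfg.oci_config_profile == "DEFAULT"  # unset variables keep their documented defaults
```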

View file

@@ -0,0 +1,140 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
from typing import Any
import httpx
import oci
from oci.generative_ai.generative_ai_client import GenerativeAiClient
from oci.generative_ai.models import ModelCollection
from openai._base_client import DefaultAsyncHttpxClient
from llama_stack.apis.inference.inference import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth
from llama_stack.providers.remote.inference.oci.config import OCIConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
logger = get_logger(name=__name__, category="inference::oci")
OCI_AUTH_TYPE_INSTANCE_PRINCIPAL = "instance_principal"
OCI_AUTH_TYPE_CONFIG_FILE = "config_file"
VALID_OCI_AUTH_TYPES = [OCI_AUTH_TYPE_INSTANCE_PRINCIPAL, OCI_AUTH_TYPE_CONFIG_FILE]
DEFAULT_OCI_REGION = "us-ashburn-1"
MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS", "CHAT"]
class OCIInferenceAdapter(OpenAIMixin):
config: OCIConfig
async def initialize(self) -> None:
"""Initialize and validate OCI configuration."""
if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
raise ValueError(
f"Invalid OCI authentication type: {self.config.oci_auth_type}."
f"Valid types are one of: {VALID_OCI_AUTH_TYPES}"
)
if not self.config.oci_compartment_id:
raise ValueError("OCI_COMPARTMENT_OCID is a required parameter. Either set in env variable or config.")
def get_base_url(self) -> str:
region = self.config.oci_region or DEFAULT_OCI_REGION
return f"https://inference.generativeai.{region}.oci.oraclecloud.com/20231130/actions/v1"
def get_api_key(self) -> str | None:
# OCI doesn't use API keys, it uses request signing
return "<NOTUSED>"
def get_extra_client_params(self) -> dict[str, Any]:
"""
Get extra parameters for the AsyncOpenAI client, including OCI-specific auth and headers.
"""
auth = self._get_auth()
compartment_id = self.config.oci_compartment_id or ""
return {
"http_client": DefaultAsyncHttpxClient(
auth=auth,
headers={
"CompartmentId": compartment_id,
},
),
}
def _get_oci_signer(self) -> oci.signer.AbstractBaseSigner | None:
if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
return oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
return None
def _get_oci_config(self) -> dict:
if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
config = {"region": self.config.oci_region}
elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
config = oci.config.from_file(self.config.oci_config_file_path, self.config.oci_config_profile)
if not config.get("region"):
raise ValueError(
"Region not specified in config. Please specify in config or with OCI_REGION env variable."
)
return config
def _get_auth(self) -> httpx.Auth:
if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
return OciInstancePrincipalAuth()
elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
return OciUserPrincipalAuth(
config_file=self.config.oci_config_file_path, profile_name=self.config.oci_config_profile
)
else:
raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")
async def list_provider_model_ids(self) -> Iterable[str]:
"""
List available models from OCI Generative AI service.
"""
oci_config = self._get_oci_config()
oci_signer = self._get_oci_signer()
compartment_id = self.config.oci_compartment_id or ""
if oci_signer is None:
client = GenerativeAiClient(config=oci_config)
else:
client = GenerativeAiClient(config=oci_config, signer=oci_signer)
models: ModelCollection = client.list_models(
compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE"
).data
seen_models = set()
model_ids = []
for model in models.items:
if model.time_deprecated or model.time_on_demand_retired:
continue
if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities:
continue
# Use display_name + model_type as the key to avoid conflicts
model_key = (model.display_name, ModelType.llm)
if model_key in seen_models:
continue
seen_models.add(model_key)
model_ids.append(model.display_name)
return model_ids
async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse:
# The constructed base URL targets OCI's "chat" action, which does not support embeddings.
raise NotImplementedError("OCI Provider does not (currently) support embeddings")

View file

@@ -11,9 +11,7 @@ from collections.abc import AsyncIterator
import litellm
from llama_stack.apis.inference import (
ChatCompletionRequest,
InferenceProvider,
JsonSchemaResponseFormat,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
@@ -23,15 +21,11 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
ToolChoice,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict_new,
convert_tooldef_to_openai_tool,
get_sampling_options,
prepare_openai_completion_params,
)
@@ -127,51 +121,6 @@ class LiteLLMOpenAIMixin(
return schema
async def _get_params(self, request: ChatCompletionRequest) -> dict:
from typing import Any
input_dict: dict[str, Any] = {}
input_dict["messages"] = [
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
]
if fmt := request.response_format:
if not isinstance(fmt, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
# Convert to dict for manipulation
fmt_dict = dict(fmt.json_schema)
name = fmt_dict["title"]
del fmt_dict["title"]
fmt_dict["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt_dict = self._add_additional_properties_recursive(fmt_dict)
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt_dict,
"strict": self.json_schema_strict,
},
}
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config and (tool_choice := request.tool_config.tool_choice):
input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
return {
"model": request.model,
"api_key": self.get_api_key(),
"api_base": self.api_base,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
def get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field

File diff suppressed because it is too large

View file

@ -21,19 +21,18 @@ from llama_stack.apis.common.content_types import (
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIFile,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
ResponseFormat,
ResponseFormatType,
SystemMessage,
SystemMessageBehavior,
ToolChoice,
ToolDefinition,
UserMessage,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
@@ -42,33 +41,19 @@ from llama_stack.models.llama.datatypes import (
RawMediaItem,
RawMessage,
RawTextItem,
Role,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
PythonListCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
log = get_logger(name=__name__, category="providers::utils")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
messages: list[RawMessage]
class CompletionRequestWithRawContent(CompletionRequest):
content: RawContent
@@ -103,28 +88,6 @@ def interleaved_content_as_str(
return _process(content)
async def convert_request_to_raw(
request: ChatCompletionRequest | CompletionRequest,
) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
if isinstance(request, ChatCompletionRequest):
messages = []
for m in request.messages:
content = await interleaved_content_convert_to_raw(m.content)
d = m.model_dump()
d["content"] = content
messages.append(RawMessage(**d))
d = request.model_dump()
d["messages"] = messages
request = ChatCompletionRequestWithRawContent(**d)
else:
d = request.model_dump()
d["content"] = await interleaved_content_convert_to_raw(request.content)
request = CompletionRequestWithRawContent(**d)
return request
async def interleaved_content_convert_to_raw(
content: InterleavedContent,
) -> RawContent:
@@ -171,6 +134,36 @@ async def interleaved_content_convert_to_raw(
return await _localize_single(content)
async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage:
"""Convert OpenAI message format to RawMessage format used by Llama formatters."""
if isinstance(message, OpenAIUserMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="user", content=content)
elif isinstance(message, OpenAISystemMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="system", content=content)
elif isinstance(message, OpenAIAssistantMessageParam):
content = await interleaved_content_convert_to_raw(message.content or "") # type: ignore[arg-type]
tool_calls = []
if message.tool_calls:
for tc in message.tool_calls:
if tc.function:
tool_calls.append(
ToolCall(
call_id=tc.id or "",
tool_name=tc.function.name or "",
arguments=tc.function.arguments or "{}",
)
)
return RawMessage(role="assistant", content=content, tool_calls=tool_calls)
elif isinstance(message, OpenAIToolMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="tool", content=content)
else:
# Handle OpenAIDeveloperMessageParam if needed
raise ValueError(f"Unsupported message type: {type(message)}")
def content_has_media(content: InterleavedContent):
def _has_media_content(c):
return isinstance(c, ImageContentItem)
@@ -181,17 +174,6 @@ def content_has_media(content: InterleavedContent):
return _has_media_content(content)
def messages_have_media(messages: list[Message]):
return any(content_has_media(m.content) for m in messages)
def request_has_media(request: ChatCompletionRequest | CompletionRequest):
if isinstance(request, ChatCompletionRequest):
return messages_have_media(request.messages)
else:
return content_has_media(request.content)
async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
if uri.startswith("http"):
async with httpx.AsyncClient() as client:
@@ -253,79 +235,6 @@ def augment_content_with_response_format_prompt(response_format, content):
return content
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
)
return formatter.tokenizer.decode(model_input.tokens)
async def chat_completion_request_to_model_input_info(
request: ChatCompletionRequest, llama_model: str
) -> tuple[str, int]:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
)
return (
formatter.tokenizer.decode(model_input.tokens),
len(model_input.tokens),
)
def chat_completion_request_to_messages(
request: ChatCompletionRequest,
llama_model: str,
) -> list[Message]:
"""Reads chat completion request and augments the messages to handle tools.
For example, for llama_3_1, add a system message with the appropriate tools or
add a user message for custom tools, etc.
"""
assert llama_model is not None, "llama_model is required"
model = resolve_model(llama_model)
if model is None:
log.error(f"Could not resolve model {llama_model}")
return request.messages
allowed_models = supported_inference_models()
descriptors = [m.descriptor() for m in allowed_models]
if model.descriptor() not in descriptors:
log.error(f"Unsupported inference model? {model.descriptor()}")
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_1(request)
elif model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
):
# llama3.2, llama3.3 follow the same tool prompt format
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
elif model.model_family == ModelFamily.llama4:
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
else:
messages = request.messages
if fmt_prompt := response_format_prompt(request.response_format):
messages.append(UserMessage(content=fmt_prompt))
return messages
def response_format_prompt(fmt: ResponseFormat | None):
if not fmt:
return None
@@ -338,128 +247,6 @@ def response_format_prompt(fmt: ResponseFormat | None):
raise ValueError(f"Unknown response format {fmt.type}")
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
messages.append(SystemMessage(content=sys_content))
has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif fmt == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages
def augment_messages_for_tools_llama(
request: ChatCompletionRequest,
custom_tool_prompt_generator,
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
sys_content = ""
custom_tools, builtin_tools = [], []
for t in request.tools:
if isinstance(t.tool_name, str):
custom_tools.append(t)
else:
builtin_tools.append(t)
if builtin_tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(builtin_tools)
sys_content += tool_template.render()
sys_content += "\n"
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
system_prompt = None
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
system_prompt = existing_system_message.content
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"
if existing_system_message and (
request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
):
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]
return messages
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
if tool_choice == ToolChoice.auto:
return ""

View file

@@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreFileBatchObject,
VectorStoreFileContentsResponse,
VectorStoreFileContentResponse,
VectorStoreFileCounts,
VectorStoreFileDeleteResponse,
VectorStoreFileLastError,
@@ -704,34 +704,35 @@ class OpenAIVectorStoreMixin(ABC):
# Unknown filter type, default to no match
raise ValueError(f"Unsupported filter type: {filter_type}")
def _chunk_to_vector_store_content(self, chunk: Chunk) -> list[VectorStoreContent]:
# content is InterleavedContent
def _chunk_to_vector_store_content(
self, chunk: Chunk, include_embeddings: bool = False, include_metadata: bool = False
) -> list[VectorStoreContent]:
def extract_fields() -> dict:
"""Extract embedding and metadata fields from chunk based on include flags."""
return {
"embedding": chunk.embedding if include_embeddings else None,
"chunk_metadata": chunk.chunk_metadata if include_metadata else None,
"metadata": chunk.metadata if include_metadata else None,
}
fields = extract_fields()
if isinstance(chunk.content, str):
content = [
VectorStoreContent(
type="text",
text=chunk.content,
)
]
content_item = VectorStoreContent(type="text", text=chunk.content, **fields)
content = [content_item]
elif isinstance(chunk.content, list):
# TODO: Add support for other types of content
content = [
VectorStoreContent(
type="text",
text=item.text,
)
for item in chunk.content
if item.type == "text"
]
content = []
for item in chunk.content:
if item.type == "text":
content_item = VectorStoreContent(type="text", text=item.text, **fields)
content.append(content_item)
else:
if chunk.content.type != "text":
raise ValueError(f"Unsupported content type: {chunk.content.type}")
content = [
VectorStoreContent(
type="text",
text=chunk.content.text,
)
]
content_item = VectorStoreContent(type="text", text=chunk.content.text, **fields)
content = [content_item]
return content
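# Illustrative note: with include_embeddings=True and include_metadata=True a text chunk
# yields VectorStoreContent items carrying embedding, chunk_metadata and metadata values;
# with both flags left at their False defaults those extra fields are simply None.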
async def openai_attach_file_to_vector_store(
@@ -820,13 +821,12 @@ class OpenAIVectorStoreMixin(ABC):
message=str(e),
)
# Create OpenAI vector store file metadata
# Save vector store file to persistent storage AFTER insert_chunks
# so that chunks include the embeddings that were generated
file_info = vector_store_file_object.model_dump(exclude={"last_error"})
file_info["filename"] = file_response.filename if file_response else ""
# Save vector store file to persistent storage (provider-specific)
dict_chunks = [c.model_dump() for c in chunks]
# TODO: update the persisted chunks to include chunk_id
await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks)
# Update file_ids and file_counts in vector store metadata
@@ -921,22 +921,27 @@ class OpenAIVectorStoreMixin(ABC):
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
include_embeddings: bool | None = False,
include_metadata: bool | None = False,
) -> VectorStoreFileContentResponse:
"""Retrieves the contents of a vector store file."""
if vector_store_id not in self.openai_vector_stores:
raise VectorStoreNotFoundError(vector_store_id)
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
# include_embeddings and include_metadata arrive directly as function parameters
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
chunks = [Chunk.model_validate(c) for c in dict_chunks]
content = []
for chunk in chunks:
content.extend(self._chunk_to_vector_store_content(chunk))
return VectorStoreFileContentsResponse(
file_id=file_id,
filename=file_info.get("filename", ""),
attributes=file_info.get("attributes", {}),
content=content,
content.extend(
self._chunk_to_vector_store_content(
chunk, include_embeddings=include_embeddings or False, include_metadata=include_metadata or False
)
)
return VectorStoreFileContentResponse(
data=content,
)
async def openai_update_vector_store_file(

View file

@@ -0,0 +1,20 @@
.git
.gitignore
.env.local
.env.*.local
.next
node_modules
npm-debug.log
*.md
.DS_Store
.vscode
.idea
playwright-report
e2e
jest.config.ts
jest.setup.ts
eslint.config.mjs
.prettierrc
.prettierignore
.nvmrc
playwright.config.ts

View file

@@ -0,0 +1,18 @@
FROM node:22.5.1-alpine
ENV NODE_ENV=production
# Install dumb-init for proper signal handling
RUN apk add --no-cache dumb-init
# Create non-root user for security
RUN addgroup --system --gid 1001 nodejs
RUN adduser --system --uid 1001 nextjs
# Install llama-stack-ui from npm
RUN npm install -g llama-stack-ui
USER nextjs
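# The packaged CLI (bin/cli.js) serves the UI on LLAMA_STACK_UI_PORT, defaulting to 8322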
ENTRYPOINT ["dumb-init", "--"]
CMD ["llama-stack-ui"]

View file

@@ -8,6 +8,9 @@ import type {
import { useRouter } from "next/navigation";
import { usePagination } from "@/hooks/use-pagination";
import { Button } from "@/components/ui/button";
import { Plus, Trash2, Search, Edit, X } from "lucide-react";
import { useState } from "react";
import { Input } from "@/components/ui/input";
import {
Table,
TableBody,
@@ -17,9 +20,21 @@ import {
TableRow,
} from "@/components/ui/table";
import { Skeleton } from "@/components/ui/skeleton";
import { useAuthClient } from "@/hooks/use-auth-client";
import {
VectorStoreEditor,
VectorStoreFormData,
} from "@/components/vector-stores/vector-store-editor";
export default function VectorStoresPage() {
const router = useRouter();
const client = useAuthClient();
const [deletingStores, setDeletingStores] = useState<Set<string>>(new Set());
const [searchTerm, setSearchTerm] = useState("");
const [showVectorStoreModal, setShowVectorStoreModal] = useState(false);
const [editingStore, setEditingStore] = useState<VectorStore | null>(null);
const [modalError, setModalError] = useState<string | null>(null);
const [showSuccessState, setShowSuccessState] = useState(false);
const {
data: stores,
status,
@@ -47,6 +62,142 @@ export default function VectorStoresPage() {
}
}, [status, hasMore, loadMore]);
// Handle ESC key to close modal
React.useEffect(() => {
const handleEscape = (event: KeyboardEvent) => {
if (event.key === "Escape" && showVectorStoreModal) {
handleCancel();
}
};
document.addEventListener("keydown", handleEscape);
return () => document.removeEventListener("keydown", handleEscape);
}, [showVectorStoreModal]);
const handleDeleteVectorStore = async (storeId: string) => {
if (
!confirm(
"Are you sure you want to delete this vector store? This action cannot be undone."
)
) {
return;
}
setDeletingStores(prev => new Set([...prev, storeId]));
try {
await client.vectorStores.delete(storeId);
// Reload the data to reflect the deletion
window.location.reload();
} catch (err: unknown) {
console.error("Failed to delete vector store:", err);
const errorMessage = err instanceof Error ? err.message : "Unknown error";
alert(`Failed to delete vector store: ${errorMessage}`);
} finally {
setDeletingStores(prev => {
const newSet = new Set(prev);
newSet.delete(storeId);
return newSet;
});
}
};
const handleSaveVectorStore = async (formData: VectorStoreFormData) => {
try {
setModalError(null);
if (editingStore) {
// Update existing vector store
const updateParams: {
name?: string;
extra_body?: Record<string, unknown>;
} = {};
// Only include fields that have changed or are provided
if (formData.name && formData.name !== editingStore.name) {
updateParams.name = formData.name;
}
// Add all parameters to extra_body (except provider_id which can't be changed)
const extraBody: Record<string, unknown> = {};
if (formData.embedding_model) {
extraBody.embedding_model = formData.embedding_model;
}
if (formData.embedding_dimension) {
extraBody.embedding_dimension = formData.embedding_dimension;
}
if (Object.keys(extraBody).length > 0) {
updateParams.extra_body = extraBody;
}
await client.vectorStores.update(editingStore.id, updateParams);
// Show success state with close button
setShowSuccessState(true);
setModalError(
"✅ Vector store updated successfully! You can close this modal and refresh the page to see changes."
);
return;
}
const createParams: {
name?: string;
provider_id?: string;
extra_body?: Record<string, unknown>;
} = {
name: formData.name || undefined,
};
// Extract provider_id to top-level (like Python client does)
if (formData.provider_id) {
createParams.provider_id = formData.provider_id;
}
// Add remaining parameters to extra_body
const extraBody: Record<string, unknown> = {};
if (formData.provider_id) {
extraBody.provider_id = formData.provider_id;
}
if (formData.embedding_model) {
extraBody.embedding_model = formData.embedding_model;
}
if (formData.embedding_dimension) {
extraBody.embedding_dimension = formData.embedding_dimension;
}
if (Object.keys(extraBody).length > 0) {
createParams.extra_body = extraBody;
}
await client.vectorStores.create(createParams);
// Show success state with close button
setShowSuccessState(true);
setModalError(
"✅ Vector store created successfully! You can close this modal and refresh the page to see changes."
);
} catch (err: unknown) {
console.error("Failed to create vector store:", err);
const errorMessage =
err instanceof Error ? err.message : "Failed to create vector store";
setModalError(errorMessage);
}
};
const handleEditVectorStore = (store: VectorStore) => {
setEditingStore(store);
setShowVectorStoreModal(true);
setModalError(null);
};
const handleCancel = () => {
setShowVectorStoreModal(false);
setEditingStore(null);
setModalError(null);
setShowSuccessState(false);
};
const renderContent = () => {
if (status === "loading") {
return (
@@ -66,7 +217,38 @@ export default function VectorStoresPage() {
return <p>No vector stores found.</p>;
}
// Filter stores based on search term
const filteredStores = stores.filter(store => {
if (!searchTerm) return true;
const searchLower = searchTerm.toLowerCase();
return (
store.id.toLowerCase().includes(searchLower) ||
(store.name && store.name.toLowerCase().includes(searchLower)) ||
(store.metadata?.provider_id &&
String(store.metadata.provider_id)
.toLowerCase()
.includes(searchLower)) ||
(store.metadata?.provider_vector_db_id &&
String(store.metadata.provider_vector_db_id)
.toLowerCase()
.includes(searchLower))
);
});
return (
<div className="space-y-4">
{/* Search Bar */}
<div className="relative flex-1 max-w-md">
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 text-muted-foreground h-4 w-4" />
<Input
placeholder="Search vector stores..."
value={searchTerm}
onChange={e => setSearchTerm(e.target.value)}
className="pl-10"
/>
</div>
<div className="overflow-auto flex-1 min-h-0">
<Table>
<TableHeader>
@@ -82,10 +264,11 @@ export default function VectorStoresPage() {
<TableHead>Usage Bytes</TableHead>
<TableHead>Provider ID</TableHead>
<TableHead>Provider Vector DB ID</TableHead>
<TableHead>Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{stores.map(store => {
{filteredStores.map(store => {
const fileCounts = store.file_counts;
const metadata = store.metadata || {};
const providerId = metadata.provider_id ?? "";
@@ -94,7 +277,9 @@ export default function VectorStoresPage() {
return (
<TableRow
key={store.id}
onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
onClick={() =>
router.push(`/logs/vector-stores/${store.id}`)
}
className="cursor-pointer hover:bg-muted/50"
>
<TableCell>
@@ -120,19 +305,102 @@ export default function VectorStoresPage() {
<TableCell>{store.usage_bytes}</TableCell>
<TableCell>{providerId}</TableCell>
<TableCell>{providerDbId}</TableCell>
<TableCell>
<div className="flex gap-2">
<Button
variant="outline"
size="sm"
onClick={e => {
e.stopPropagation();
handleEditVectorStore(store);
}}
>
<Edit className="h-4 w-4" />
</Button>
<Button
variant="outline"
size="sm"
onClick={e => {
e.stopPropagation();
handleDeleteVectorStore(store.id);
}}
disabled={deletingStores.has(store.id)}
>
{deletingStores.has(store.id) ? (
"Deleting..."
) : (
<>
<Trash2 className="h-4 w-4" />
</>
)}
</Button>
</div>
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
</div>
);
};
return (
<div className="space-y-4">
<div className="flex items-center justify-between">
<h1 className="text-2xl font-semibold">Vector Stores</h1>
<Button
onClick={() => setShowVectorStoreModal(true)}
disabled={status === "loading"}
>
<Plus className="h-4 w-4 mr-2" />
New Vector Store
</Button>
</div>
{renderContent()}
{/* Create Vector Store Modal */}
{showVectorStoreModal && (
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
<div className="bg-background border rounded-lg shadow-lg max-w-2xl w-full mx-4 max-h-[90vh] overflow-hidden">
<div className="p-6 border-b flex items-center justify-between">
<h2 className="text-2xl font-bold">
{editingStore ? "Edit Vector Store" : "Create New Vector Store"}
</h2>
<Button
variant="ghost"
size="sm"
onClick={handleCancel}
className="p-1 h-auto"
>
<X className="h-4 w-4" />
</Button>
</div>
<div className="p-6 overflow-y-auto max-h-[calc(90vh-120px)]">
<VectorStoreEditor
onSave={handleSaveVectorStore}
onCancel={handleCancel}
error={modalError}
showSuccessState={showSuccessState}
isEditing={!!editingStore}
initialData={
editingStore
? {
name: editingStore.name || "",
embedding_model:
editingStore.metadata?.embedding_model || "",
embedding_dimension:
editingStore.metadata?.embedding_dimension || 768,
provider_id: editingStore.metadata?.provider_id || "",
}
: undefined
}
/>
</div>
</div>
</div>
)}
</div>
);
}

src/llama_stack_ui/bin/cli.js Executable file
View file

@@ -0,0 +1,34 @@
#!/usr/bin/env node
const { spawn } = require('child_process');
const path = require('path');
const port = process.env.LLAMA_STACK_UI_PORT || 8322;
const uiDir = path.resolve(__dirname, '..');
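// The standalone server bundle is produced by `next build` (output: "standalone") and
// completed by scripts/postbuild.js, which copies public/ and .next/static alongside it.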
const serverPath = path.join(uiDir, '.next', 'standalone', 'ui', 'src', 'llama_stack_ui', 'server.js');
const serverDir = path.dirname(serverPath);
console.log(`Starting Llama Stack UI on http://localhost:${port}`);
const child = spawn(process.execPath, [serverPath], {
cwd: serverDir,
stdio: 'inherit',
env: {
...process.env,
PORT: port,
},
});
process.on('SIGINT', () => {
child.kill('SIGINT');
process.exit(0);
});
process.on('SIGTERM', () => {
child.kill('SIGTERM');
process.exit(0);
});
child.on('exit', (code) => {
process.exit(code);
});

View file

@@ -2,7 +2,7 @@ import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { PromptEditor } from "./prompt-editor";
import type { Prompt, PromptFormData } from "./types";
import type { Prompt } from "./types";
describe("PromptEditor", () => {
const mockOnSave = jest.fn();

View file

@@ -12,6 +12,20 @@ jest.mock("next/navigation", () => ({
}),
}));
// Mock NextAuth
jest.mock("next-auth/react", () => ({
useSession: () => ({
data: {
accessToken: "mock-access-token",
user: {
id: "mock-user-id",
email: "test@example.com",
},
},
status: "authenticated",
}),
}));
describe("VectorStoreDetailView", () => {
const defaultProps = {
store: null,

View file

@@ -1,16 +1,18 @@
"use client";
import { useRouter } from "next/navigation";
import { useState, useEffect } from "react";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
import { Button } from "@/components/ui/button";
import { useAuthClient } from "@/hooks/use-auth-client";
import { Edit2, Trash2, X } from "lucide-react";
import {
DetailLoadingView,
DetailErrorView,
DetailNotFoundView,
DetailLayout,
PropertiesCard,
PropertyItem,
} from "@/components/layout/detail-layout";
@@ -23,6 +25,7 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table";
import { VectorStoreEditor, VectorStoreFormData } from "./vector-store-editor";
interface VectorStoreDetailViewProps {
store: VectorStore | null;
@@ -43,21 +46,122 @@ export function VectorStoreDetailView({
errorFiles,
id,
}: VectorStoreDetailViewProps) {
const title = "Vector Store Details";
const router = useRouter();
const client = useAuthClient();
const [isDeleting, setIsDeleting] = useState(false);
const [showEditModal, setShowEditModal] = useState(false);
const [modalError, setModalError] = useState<string | null>(null);
const [showSuccessState, setShowSuccessState] = useState(false);
// Handle ESC key to close modal
useEffect(() => {
const handleEscape = (event: KeyboardEvent) => {
if (event.key === "Escape" && showEditModal) {
handleCancel();
}
};
document.addEventListener("keydown", handleEscape);
return () => document.removeEventListener("keydown", handleEscape);
}, [showEditModal]);
const handleFileClick = (fileId: string) => {
router.push(`/logs/vector-stores/${id}/files/${fileId}`);
};
const handleEditVectorStore = () => {
setShowEditModal(true);
setModalError(null);
setShowSuccessState(false);
};
const handleCancel = () => {
setShowEditModal(false);
setModalError(null);
setShowSuccessState(false);
};
const handleSaveVectorStore = async (formData: VectorStoreFormData) => {
try {
setModalError(null);
// Update existing vector store (same logic as list page)
const updateParams: {
name?: string;
extra_body?: Record<string, unknown>;
} = {};
// Only include fields that have changed or are provided
if (formData.name && formData.name !== store?.name) {
updateParams.name = formData.name;
}
// Add all parameters to extra_body (except provider_id which can't be changed)
const extraBody: Record<string, unknown> = {};
if (formData.embedding_model) {
extraBody.embedding_model = formData.embedding_model;
}
if (formData.embedding_dimension) {
extraBody.embedding_dimension = formData.embedding_dimension;
}
if (Object.keys(extraBody).length > 0) {
updateParams.extra_body = extraBody;
}
await client.vectorStores.update(id, updateParams);
// Show success state
setShowSuccessState(true);
setModalError(
"✅ Vector store updated successfully! You can close this modal and refresh the page to see changes."
);
} catch (err: unknown) {
console.error("Failed to update vector store:", err);
const errorMessage =
err instanceof Error ? err.message : "Failed to update vector store";
setModalError(errorMessage);
}
};
const handleDeleteVectorStore = async () => {
if (
!confirm(
"Are you sure you want to delete this vector store? This action cannot be undone."
)
) {
return;
}
setIsDeleting(true);
try {
await client.vectorStores.delete(id);
// Redirect to the vector stores list after successful deletion
router.push("/logs/vector-stores");
} catch (err: unknown) {
console.error("Failed to delete vector store:", err);
const errorMessage = err instanceof Error ? err.message : "Unknown error";
alert(`Failed to delete vector store: ${errorMessage}`);
} finally {
setIsDeleting(false);
}
};
if (errorStore) {
return <DetailErrorView title={title} id={id} error={errorStore} />;
return (
<DetailErrorView
title="Vector Store Details"
id={id}
error={errorStore}
/>
);
}
if (isLoadingStore) {
return <DetailLoadingView title={title} />;
return <DetailLoadingView />;
}
if (!store) {
return <DetailNotFoundView title={title} id={id} />;
return <DetailNotFoundView title="Vector Store Details" id={id} />;
}
const mainContent = (
@@ -138,6 +242,73 @@ export function VectorStoreDetailView({
);
return (
<DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
<>
<div className="flex items-center justify-between mb-6">
<h1 className="text-2xl font-bold">Vector Store Details</h1>
<div className="flex gap-2">
<Button
variant="outline"
onClick={handleEditVectorStore}
disabled={isDeleting}
>
<Edit2 className="h-4 w-4 mr-2" />
Edit
</Button>
<Button
variant="destructive"
onClick={handleDeleteVectorStore}
disabled={isDeleting}
>
{isDeleting ? (
"Deleting..."
) : (
<>
<Trash2 className="h-4 w-4 mr-2" />
Delete
</>
)}
</Button>
</div>
</div>
<div className="flex flex-col md:flex-row gap-6">
<div className="flex-grow md:w-2/3 space-y-6">{mainContent}</div>
<div className="md:w-1/3">{sidebar}</div>
</div>
{/* Edit Vector Store Modal */}
{showEditModal && (
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
<div className="bg-background border rounded-lg shadow-lg max-w-2xl w-full mx-4 max-h-[90vh] overflow-hidden">
<div className="p-6 border-b flex items-center justify-between">
<h2 className="text-2xl font-bold">Edit Vector Store</h2>
<Button
variant="ghost"
size="sm"
onClick={handleCancel}
className="p-1 h-auto"
>
<X className="h-4 w-4" />
</Button>
</div>
<div className="p-6 overflow-y-auto max-h-[calc(90vh-120px)]">
<VectorStoreEditor
onSave={handleSaveVectorStore}
onCancel={handleCancel}
error={modalError}
showSuccessState={showSuccessState}
isEditing={true}
initialData={{
name: store?.name || "",
embedding_model: store?.metadata?.embedding_model || "",
embedding_dimension:
store?.metadata?.embedding_dimension || 768,
provider_id: store?.metadata?.provider_id || "",
}}
/>
</div>
</div>
</div>
)}
</>
);
}

View file

@@ -0,0 +1,235 @@
"use client";
import { useState, useEffect } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Card, CardContent } from "@/components/ui/card";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { useAuthClient } from "@/hooks/use-auth-client";
import type { Model } from "llama-stack-client/resources/models";
export interface VectorStoreFormData {
name: string;
embedding_model?: string;
embedding_dimension?: number;
provider_id?: string;
}
interface VectorStoreEditorProps {
onSave: (formData: VectorStoreFormData) => Promise<void>;
onCancel: () => void;
error?: string | null;
initialData?: VectorStoreFormData;
showSuccessState?: boolean;
isEditing?: boolean;
}
export function VectorStoreEditor({
onSave,
onCancel,
error,
initialData,
showSuccessState,
isEditing = false,
}: VectorStoreEditorProps) {
const client = useAuthClient();
const [formData, setFormData] = useState<VectorStoreFormData>(
initialData || {
name: "",
embedding_model: "",
embedding_dimension: 768,
provider_id: "",
}
);
const [loading, setLoading] = useState(false);
const [models, setModels] = useState<Model[]>([]);
const [modelsLoading, setModelsLoading] = useState(true);
const [modelsError, setModelsError] = useState<string | null>(null);
const embeddingModels = models.filter(
model => model.custom_metadata?.model_type === "embedding"
);
useEffect(() => {
const fetchModels = async () => {
try {
setModelsLoading(true);
setModelsError(null);
const modelList = await client.models.list();
setModels(modelList);
// Set default embedding model if available
const embeddingModelsList = modelList.filter(model => {
return model.custom_metadata?.model_type === "embedding";
});
if (embeddingModelsList.length > 0 && !formData.embedding_model) {
setFormData(prev => ({
...prev,
embedding_model: embeddingModelsList[0].id,
}));
}
} catch (err) {
console.error("Failed to load models:", err);
setModelsError(
err instanceof Error ? err.message : "Failed to load models"
);
} finally {
setModelsLoading(false);
}
};
fetchModels();
}, [client]);
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setLoading(true);
try {
await onSave(formData);
} finally {
setLoading(false);
}
};
return (
<Card>
<CardContent className="pt-6">
<form onSubmit={handleSubmit} className="space-y-4">
<div className="space-y-2">
<Label htmlFor="name">Name</Label>
<Input
id="name"
value={formData.name}
onChange={e => setFormData({ ...formData, name: e.target.value })}
placeholder="Enter vector store name"
required
/>
</div>
<div className="space-y-2">
<Label htmlFor="embedding_model">Embedding Model (Optional)</Label>
{modelsLoading ? (
<div className="text-sm text-muted-foreground">
Loading models... ({models.length} loaded)
</div>
) : modelsError ? (
<div className="text-sm text-destructive">
Error: {modelsError}
</div>
) : embeddingModels.length === 0 ? (
<div className="text-sm text-muted-foreground">
No embedding models available ({models.length} total models)
</div>
) : (
<Select
value={formData.embedding_model}
onValueChange={value =>
setFormData({ ...formData, embedding_model: value })
}
>
<SelectTrigger>
<SelectValue placeholder="Select Embedding Model" />
</SelectTrigger>
<SelectContent>
{embeddingModels.map(model => (
<SelectItem key={model.id} value={model.id}>
{model.id}
</SelectItem>
))}
</SelectContent>
</Select>
)}
{formData.embedding_model && (
<p className="text-xs text-muted-foreground mt-1">
Dimension:{" "}
{embeddingModels.find(m => m.id === formData.embedding_model)
?.custom_metadata?.embedding_dimension || "Unknown"}
</p>
)}
</div>
<div className="space-y-2">
<Label htmlFor="embedding_dimension">
Embedding Dimension (Optional)
</Label>
<Input
id="embedding_dimension"
type="number"
value={formData.embedding_dimension}
onChange={e =>
setFormData({
...formData,
embedding_dimension: parseInt(e.target.value) || 768,
})
}
placeholder="768"
/>
</div>
<div className="space-y-2">
<Label htmlFor="provider_id">
Provider ID {isEditing ? "(Read-only)" : "(Optional)"}
</Label>
<Input
id="provider_id"
value={formData.provider_id}
onChange={e =>
setFormData({ ...formData, provider_id: e.target.value })
}
placeholder="e.g., faiss, chroma, sqlite"
disabled={isEditing}
/>
{isEditing && (
<p className="text-xs text-muted-foreground">
Provider ID cannot be changed after vector store creation
</p>
)}
</div>
{error && (
<div
className={`text-sm p-3 rounded ${
error.startsWith("✅")
? "text-green-700 bg-green-50 border border-green-200"
: "text-destructive bg-destructive/10"
}`}
>
{error}
</div>
)}
<div className="flex gap-2 pt-4">
{showSuccessState ? (
<Button type="button" onClick={onCancel}>
Close
</Button>
) : (
<>
<Button type="submit" disabled={loading}>
{loading
? initialData
? "Updating..."
: "Creating..."
: initialData
? "Update Vector Store"
: "Create Vector Store"}
</Button>
<Button type="button" variant="outline" onClick={onCancel}>
Cancel
</Button>
</>
)}
</div>
</form>
</CardContent>
</Card>
);
}

View file

@@ -34,9 +34,35 @@ export class ContentsAPI {
async getFileContents(
vectorStoreId: string,
fileId: string
fileId: string,
includeEmbeddings: boolean = true,
includeMetadata: boolean = true
): Promise<VectorStoreContentsResponse> {
return this.client.vectorStores.files.content(vectorStoreId, fileId);
try {
// Pass the embeddings and metadata flags as query parameters (OpenAI-compatible pattern)
const result = await this.client.vectorStores.files.content(
vectorStoreId,
fileId,
{
query: {
include_embeddings: includeEmbeddings,
include_metadata: includeMetadata,
},
}
);
return result;
} catch (error) {
console.error("ContentsAPI.getFileContents error:", error);
throw error;
}
}
async getContent(
@@ -70,11 +96,15 @@ export class ContentsAPI {
order?: string;
after?: string;
before?: string;
includeEmbeddings?: boolean;
includeMetadata?: boolean;
}
): Promise<VectorStoreListContentsResponse> {
const fileContents = await this.client.vectorStores.files.content(
const fileContents = await this.getFileContents(
vectorStoreId,
fileId
fileId,
options?.includeEmbeddings ?? true,
options?.includeMetadata ?? true
);
const contentItems: VectorStoreContentItem[] = [];
@@ -82,7 +112,7 @@ export class ContentsAPI {
const rawContent = content as Record<string, unknown>;
// Extract actual fields from the API response
const embedding = rawContent.embedding || undefined;
const embedding = rawContent.embedding as number[] | undefined;
const created_timestamp =
rawContent.created_timestamp ||
rawContent.created_at ||

View file

@@ -1,7 +1,13 @@
import type { NextConfig } from "next";
const nextConfig: NextConfig = {
/* config options here */
typescript: {
ignoreBuildErrors: true,
},
output: "standalone",
images: {
unoptimized: true,
},
};
export default nextConfig;

View file

@@ -1,12 +1,13 @@
{
"name": "ui",
"version": "0.1.0",
"name": "llama-stack-ui",
"version": "0.4.0-alpha.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "ui",
"version": "0.1.0",
"name": "llama-stack-ui",
"version": "0.4.0-alpha.1",
"license": "MIT",
"dependencies": {
"@radix-ui/react-collapsible": "^1.1.12",
"@radix-ui/react-dialog": "^1.1.15",
@@ -20,7 +21,7 @@
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"framer-motion": "^12.23.24",
"llama-stack-client": "github:llamastack/llama-stack-client-typescript",
"llama-stack-client": "^0.3.1",
"lucide-react": "^0.545.0",
"next": "15.5.4",
"next-auth": "^4.24.11",
@@ -9684,8 +9685,9 @@
"license": "MIT"
},
"node_modules/llama-stack-client": {
"version": "0.4.0-alpha.1",
"resolved": "git+ssh://git@github.com/llamastack/llama-stack-client-typescript.git#78de4862c4b7d77939ac210fa9f9bde77a2c5c5f",
"version": "0.3.1",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.3.1.tgz",
"integrity": "sha512-4aYoF2aAQiBSfxyZEtczeQmJn8q9T22ePDqGhR+ej5RG6a8wvl5B3v7ZoKuFkft+vcP/kbJ58GQZEPLekxekZA==",
"license": "MIT",
"dependencies": {
"@types/node": "^18.11.18",

View file

@@ -1,11 +1,31 @@
{
"name": "ui",
"version": "0.1.0",
"private": true,
"name": "llama-stack-ui",
"version": "0.4.0-alpha.4",
"description": "Web UI for Llama Stack",
"license": "MIT",
"author": "Llama Stack <llamastack@meta.com>",
"repository": {
"type": "git",
"url": "https://github.com/llamastack/llama-stack.git",
"directory": "llama_stack_ui"
},
"bin": {
"llama-stack-ui": "bin/cli.js"
},
"files": [
"bin",
".next",
"public",
"next.config.ts",
"instrumentation.ts",
"tsconfig.json",
"package.json"
],
"scripts": {
"dev": "next dev --turbopack --port ${LLAMA_STACK_UI_PORT:-8322}",
"build": "next build",
"build": "next build && node scripts/postbuild.js",
"start": "next start",
"prepublishOnly": "npm run build",
"lint": "next lint",
"format": "prettier --write \"./**/*.{ts,tsx}\"",
"format:check": "prettier --check \"./**/*.{ts,tsx}\"",
@@ -25,7 +45,7 @@
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"framer-motion": "^12.23.24",
"llama-stack-client": "github:llamastack/llama-stack-client-typescript",
"llama-stack-client": "^0.3.1",
"lucide-react": "^0.545.0",
"next": "15.5.4",
"next-auth": "^4.24.11",

View file

@@ -0,0 +1,40 @@
const fs = require('fs');
const path = require('path');
// Copy public directory to standalone
const publicSrc = path.join(__dirname, '..', 'public');
const publicDest = path.join(__dirname, '..', '.next', 'standalone', 'ui', 'src', 'llama_stack_ui', 'public');
if (fs.existsSync(publicSrc) && !fs.existsSync(publicDest)) {
console.log('Copying public directory to standalone...');
copyDir(publicSrc, publicDest);
}
// Copy .next/static to standalone
const staticSrc = path.join(__dirname, '..', '.next', 'static');
const staticDest = path.join(__dirname, '..', '.next', 'standalone', 'ui', 'src', 'llama_stack_ui', '.next', 'static');
if (fs.existsSync(staticSrc) && !fs.existsSync(staticDest)) {
console.log('Copying .next/static to standalone...');
copyDir(staticSrc, staticDest);
}
function copyDir(src, dest) {
if (!fs.existsSync(dest)) {
fs.mkdirSync(dest, { recursive: true });
}
const files = fs.readdirSync(src);
files.forEach((file) => {
const srcFile = path.join(src, file);
const destFile = path.join(dest, file);
if (fs.statSync(srcFile).isDirectory()) {
copyDir(srcFile, destFile);
} else {
fs.copyFileSync(srcFile, destFile);
}
});
}
console.log('Postbuild complete!');

View file

@@ -516,3 +516,169 @@ def test_response_with_instructions(openai_client, client_with_models, text_mode
# Verify that instructions from the previous response were not carried over to the next response
assert response_with_instructions2.instructions == instructions2
@pytest.mark.skip(reason="Tool calling is not reliable.")
def test_max_tool_calls_with_function_tools(openai_client, client_with_models, text_model_id):
"""Test handling of max_tool_calls with function tools in responses."""
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
client = openai_client
max_tool_calls = 1
tools = [
{
"type": "function",
"name": "get_weather",
"description": "Get weather information for a specified location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name (e.g., 'New York', 'London')",
},
},
},
},
{
"type": "function",
"name": "get_time",
"description": "Get current time for a specified location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name (e.g., 'New York', 'London')",
},
},
},
},
]
# First create a response that triggers function tools
response = client.responses.create(
model=text_model_id,
input="Can you tell me the weather in Paris and the current time?",
tools=tools,
stream=False,
max_tool_calls=max_tool_calls,
)
# Verify we got two function calls and that max_tool_calls does not affect function tools
assert len(response.output) == 2
assert response.output[0].type == "function_call"
assert response.output[0].name == "get_weather"
assert response.output[0].status == "completed"
assert response.output[1].type == "function_call"
assert response.output[1].name == "get_time"
assert response.output[1].status == "completed"
# Verify we have a valid max_tool_calls field
assert response.max_tool_calls == max_tool_calls
def test_max_tool_calls_invalid(openai_client, client_with_models, text_model_id):
"""Test handling of invalid max_tool_calls in responses."""
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
client = openai_client
input = "Search for today's top technology news."
invalid_max_tool_calls = 0
tools = [
{"type": "web_search"},
]
# Create a response with an invalid max_tool_calls value i.e. 0
# Handle ValueError from Llama Stack and BadRequestError from the OpenAI client
with pytest.raises((ValueError, BadRequestError)) as excinfo:
client.responses.create(
model=text_model_id,
input=input,
tools=tools,
stream=False,
max_tool_calls=invalid_max_tool_calls,
)
error_message = str(excinfo.value)
assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, (
f"Expected error message about invalid max_tool_calls, got: {error_message}"
)
def test_max_tool_calls_with_builtin_tools(openai_client, client_with_models, text_model_id):
"""Test handling of max_tool_calls with built-in tools in responses."""
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
client = openai_client
input = "Search for today's top technology and a positive news story. You MUST make exactly two separate web search calls."
max_tool_calls = [1, 5]
tools = [
{"type": "web_search"},
]
# First create a response that triggers web_search tools without max_tool_calls
response = client.responses.create(
model=text_model_id,
input=input,
tools=tools,
stream=False,
)
# Verify we got two web search calls followed by a message
assert len(response.output) == 3
assert response.output[0].type == "web_search_call"
assert response.output[0].status == "completed"
assert response.output[1].type == "web_search_call"
assert response.output[1].status == "completed"
assert response.output[2].type == "message"
assert response.output[2].status == "completed"
assert response.output[2].role == "assistant"
# Next create a response that triggers web_search tools with max_tool_calls set to 1
response_2 = client.responses.create(
model=text_model_id,
input=input,
tools=tools,
stream=False,
max_tool_calls=max_tool_calls[0],
)
# Verify we got one web search tool call followed by a message
assert len(response_2.output) == 2
assert response_2.output[0].type == "web_search_call"
assert response_2.output[0].status == "completed"
assert response_2.output[1].type == "message"
assert response_2.output[1].status == "completed"
assert response_2.output[1].role == "assistant"
# Verify we have a valid max_tool_calls field
assert response_2.max_tool_calls == max_tool_calls[0]
# Finally create a response that triggers web_search tools with max_tool_calls set to 5
response_3 = client.responses.create(
model=text_model_id,
input=input,
tools=tools,
stream=False,
max_tool_calls=max_tool_calls[1],
)
# Verify we got two web search calls followed by a message
assert len(response_3.output) == 3
assert response_3.output[0].type == "web_search_call"
assert response_3.output[0].status == "completed"
assert response_3.output[1].type == "web_search_call"
assert response_3.output[1].status == "completed"
assert response_3.output[2].type == "message"
assert response_3.output[2].status == "completed"
assert response_3.output[2].role == "assistant"
# Verify we have a valid max_tool_calls field
assert response_3.max_tool_calls == max_tool_calls[1]

View file

@@ -334,7 +334,13 @@ def require_server(llama_stack_client):
@pytest.fixture(scope="session")
def openai_client(llama_stack_client, require_server):
base_url = f"{llama_stack_client.base_url}/v1"
return OpenAI(base_url=base_url, api_key="fake")
client = OpenAI(base_url=base_url, api_key="fake", max_retries=0, timeout=30.0)
yield client
# Cleanup: close HTTP connections
try:
client.close()
except Exception:
pass
@pytest.fixture(params=["openai_client", "client_with_models"])

View file

@@ -54,6 +54,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
# {"error":{"message":"Unknown request URL: GET /openai/v1/completions. Please check the URL for typos,
# or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}}
"remote::groq",
"remote::oci",
"remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404
"remote::anthropic", # at least claude-3-{5,7}-{haiku,sonnet}-* / claude-{sonnet,opus}-4-* are not supported
"remote::azure", # {'error': {'code': 'OperationNotSupported', 'message': 'The completion operation

View file

@@ -138,6 +138,7 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
"remote::runpod",
"remote::sambanova",
"remote::tgi",
"remote::oci",
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")

View file

@@ -2,6 +2,10 @@
This directory contains recorded inference API responses used for deterministic testing without requiring live API access.
For more information, see the
[docs](https://llamastack.github.io/docs/contributing/testing/record-replay).
This README provides more technical information.
## Structure
- `responses/` - JSON files containing request/response pairs for inference operations
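As a quick way to see what a recording actually stores, the minimal sketch below loads one file and lists its top-level fields; it assumes only that the files under `responses/` are JSON, as noted above.

```python
# Minimal, schema-agnostic sketch: load one recorded file and list its top-level fields.
# Assumes only that files under responses/ are JSON, as described above.
import json
from pathlib import Path

sample = next(Path("responses").glob("*.json"))
record = json.loads(sample.read_text())
print(sample.name, "->", list(record.keys()))
```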

View file

@@ -115,7 +115,15 @@ def openai_client(base_url, api_key, provider):
client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
return client
return OpenAI(
client = OpenAI(
base_url=base_url,
api_key=api_key,
max_retries=0,
timeout=30.0,
)
yield client
# Cleanup: close HTTP connections
try:
client.close()
except Exception:
pass

View file

@@ -65,8 +65,14 @@ class TestConversationResponses:
conversation_items = openai_client.conversations.items.list(conversation.id)
assert len(conversation_items.data) >= 4 # 2 user + 2 assistant messages
@pytest.mark.timeout(60, method="thread")
def test_conversation_context_loading(self, openai_client, text_model_id):
"""Test that conversation context is properly loaded for responses."""
"""Test that conversation context is properly loaded for responses.
Note: 60s timeout added due to CI-specific deadlock in pytest/OpenAI client/httpx
after running 25+ tests. Hangs before first HTTP request is made. Works fine locally.
Investigation needed: connection pool exhaustion or event loop state issue.
"""
conversation = openai_client.conversations.create(
items=[
{"type": "message", "role": "user", "content": "My name is Alice. I like to eat apples."},

View file

@@ -11,6 +11,7 @@ import pytest
from llama_stack_client import BadRequestError
from openai import BadRequestError as OpenAIBadRequestError
from llama_stack.apis.files import ExpiresAfter
from llama_stack.apis.vector_io import Chunk
from llama_stack.core.library_client import LlamaStackAsLibraryClient
from llama_stack.log import get_logger
@@ -907,16 +908,16 @@ def test_openai_vector_store_retrieve_file_contents(
)
assert file_contents is not None
assert len(file_contents.content) == 1
content = file_contents.content[0]
assert file_contents.object == "vector_store.file_content.page"
assert len(file_contents.data) == 1
content = file_contents.data[0]
# llama-stack-client returns a model, while openai-python returns a plain dict
if not isinstance(content, dict):
content = content.model_dump()
assert content["type"] == "text"
assert content["text"] == test_content.decode("utf-8")
assert file_contents.filename == file_name
assert file_contents.attributes == attributes
assert file_contents.has_more is False
@vector_provider_wrapper
@@ -1483,14 +1484,12 @@ def test_openai_vector_store_file_batch_retrieve_contents(
)
assert file_contents is not None
assert file_contents.filename == file_data[i][0]
assert len(file_contents.content) > 0
assert file_contents.object == "vector_store.file_content.page"
assert len(file_contents.data) > 0
# Verify the content matches what we uploaded
content_text = (
file_contents.content[0].text
if hasattr(file_contents.content[0], "text")
else file_contents.content[0]["text"]
file_contents.data[0].text if hasattr(file_contents.data[0], "text") else file_contents.data[0]["text"]
)
assert file_data[i][1].decode("utf-8") in content_text
@@ -1606,3 +1605,97 @@ def test_openai_vector_store_embedding_config_from_metadata(
assert "metadata_config_store" in store_names
assert "consistent_config_store" in store_names
@vector_provider_wrapper
def test_openai_vector_store_file_contents_with_extra_query(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test that vector store file contents endpoint supports extra_query parameter."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
compat_client = compat_client_with_empty_stores
# Create a vector store
vector_store = compat_client.vector_stores.create(
name="test_extra_query_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
# Create and attach a file
test_content = b"This is test content for extra_query validation."
with BytesIO(test_content) as file_buffer:
file_buffer.name = "test_extra_query.txt"
file = compat_client.files.create(
file=file_buffer,
purpose="assistants",
expires_after=ExpiresAfter(anchor="created_at", seconds=86400),
)
file_attach_response = compat_client.vector_stores.files.create(
vector_store_id=vector_store.id,
file_id=file.id,
extra_body={"embedding_model": embedding_model_id},
)
assert file_attach_response.status == "completed"
# Wait for processing
time.sleep(2)
# Test that extra_query parameter is accepted and processed
content_with_extra_query = compat_client.vector_stores.files.content(
vector_store_id=vector_store.id,
file_id=file.id,
extra_query={"include_embeddings": True, "include_metadata": True},
)
# Test without extra_query for comparison
content_without_extra_query = compat_client.vector_stores.files.content(
vector_store_id=vector_store.id,
file_id=file.id,
)
# Validate that both calls succeed
assert content_with_extra_query is not None
assert content_without_extra_query is not None
assert len(content_with_extra_query.data) > 0
assert len(content_without_extra_query.data) > 0
# Validate that extra_query parameter is processed correctly
# Both should have the embedding/metadata fields available (may be None based on flags)
first_chunk_with_flags = content_with_extra_query.data[0]
first_chunk_without_flags = content_without_extra_query.data[0]
# The key validation: extra_query fields are present in the response
# Handle both dict and object responses (different clients may return different formats)
def has_field(obj, field):
if isinstance(obj, dict):
return field in obj
else:
return hasattr(obj, field)
# Validate that all expected fields are present in both responses
expected_fields = ["embedding", "chunk_metadata", "metadata", "text"]
for field in expected_fields:
assert has_field(first_chunk_with_flags, field), f"Field '{field}' missing from response with extra_query"
assert has_field(first_chunk_without_flags, field), f"Field '{field}' missing from response without extra_query"
# Validate content is the same
def get_field(obj, field):
if isinstance(obj, dict):
return obj[field]
else:
return getattr(obj, field)
assert get_field(first_chunk_with_flags, "text") == test_content.decode("utf-8")
assert get_field(first_chunk_without_flags, "text") == test_content.decode("utf-8")
with_flags_embedding = get_field(first_chunk_with_flags, "embedding")
without_flags_embedding = get_field(first_chunk_without_flags, "embedding")
# Validate that embeddings are included when requested and excluded when not requested
assert with_flags_embedding is not None, "Embeddings should be included when include_embeddings=True"
assert len(with_flags_embedding) > 0, "Embedding should be a non-empty list"
assert without_flags_embedding is None, "Embeddings should not be included when they are not requested"

View file

@@ -55,3 +55,65 @@ async def test_create_vector_stores_multiple_providers_missing_provider_id_error
with pytest.raises(ValueError, match="Multiple vector_io providers available"):
await router.openai_create_vector_store(request)
async def test_update_vector_store_provider_id_change_fails():
"""Test that updating a vector store with a different provider_id fails with clear error."""
mock_routing_table = Mock()
# Mock an existing vector store with provider_id "faiss"
mock_existing_store = Mock()
mock_existing_store.provider_id = "inline::faiss"
mock_existing_store.identifier = "vs_123"
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
mock_routing_table.get_provider_impl = AsyncMock(
return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
)
router = VectorIORouter(mock_routing_table)
# Try to update with different provider_id in metadata - this should fail
with pytest.raises(ValueError, match="provider_id cannot be changed after vector store creation"):
await router.openai_update_vector_store(
vector_store_id="vs_123",
name="updated_name",
metadata={"provider_id": "inline::sqlite"}, # Different provider_id
)
# Verify the existing store was looked up to check provider_id
mock_routing_table.get_object_by_identifier.assert_called_once_with("vector_store", "vs_123")
# Provider should not be called since validation failed
mock_routing_table.get_provider_impl.assert_not_called()
async def test_update_vector_store_same_provider_id_succeeds():
"""Test that updating a vector store with the same provider_id succeeds."""
mock_routing_table = Mock()
# Mock an existing vector store with provider_id "faiss"
mock_existing_store = Mock()
mock_existing_store.provider_id = "inline::faiss"
mock_existing_store.identifier = "vs_123"
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
mock_routing_table.get_provider_impl = AsyncMock(
return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
)
router = VectorIORouter(mock_routing_table)
# Update with same provider_id should succeed
await router.openai_update_vector_store(
vector_store_id="vs_123",
name="updated_name",
metadata={"provider_id": "inline::faiss"}, # Same provider_id
)
# Verify the provider update method was called
mock_routing_table.get_provider_impl.assert_called_once_with("vs_123")
provider = await mock_routing_table.get_provider_impl("vs_123")
provider.openai_update_vector_store.assert_called_once_with(
vector_store_id="vs_123", name="updated_name", expires_after=None, metadata={"provider_id": "inline::faiss"}
)

View file

@@ -1,303 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionMessage,
StopReason,
SystemMessage,
SystemMessageBehavior,
ToolCall,
ToolConfig,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
chat_completion_request_to_prompt,
interleaved_content_as_str,
)
MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"
async def test_system_default():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
async def test_system_builtin_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
async def test_system_custom_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
)
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 3
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
async def test_system_custom_and_builtin():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 3
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
async def test_completion_message_encoding():
request = ChatCompletionRequest(
model=MODEL3_2,
messages=[
UserMessage(content="hello"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
tool_name="custom1",
arguments='{"param1": "value1"}', # arguments must be a JSON string
call_id="123",
)
],
),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
)
prompt = await chat_completion_request_to_prompt(request, request.model)
assert '[custom1(param1="value1")]' in prompt
request.model = MODEL
request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
prompt = await chat_completion_request_to_prompt(request, request.model)
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
async def test_user_provided_system_message():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert messages[-1].content == content
async def test_replace_system_message_behavior_builtin_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools_with_template():
content = "Hello !"
system_prompt = "You are a pirate {{ function_description }}"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
# function description is present in the system prompt
assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import Mock
import pytest
from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
ModelRunner,
)
class TestModelRunner:
"""Test ModelRunner task dispatching for model-parallel inference."""
def test_chat_completion_task_dispatch(self):
"""Verify ModelRunner correctly dispatches chat_completion tasks."""
# Create a mock generator
mock_generator = Mock()
mock_generator.chat_completion = Mock(return_value=iter([]))
runner = ModelRunner(mock_generator)
# Create a chat_completion task
fake_params = {"model": "test"}
fake_messages = [{"role": "user", "content": "test"}]
task = ("chat_completion", [fake_params, fake_messages])
# Execute task
runner(task)
# Verify chat_completion was called with correct arguments
mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)
def test_invalid_task_type_raises_error(self):
"""Verify ModelRunner rejects invalid task types."""
mock_generator = Mock()
runner = ModelRunner(mock_generator)
with pytest.raises(ValueError, match="Unexpected task type"):
runner(("invalid_task", []))

View file

@@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.inference import CompletionMessage, UserMessage
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
@@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
# Run the shield
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
OpenAIUserMessageParam(content="Hello, how are you?"),
OpenAIAssistantMessageParam(
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]
@@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
# Mock Guardrails API response
mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
# Run the shield
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
OpenAIUserMessageParam(content="Hello, how are you?"),
OpenAIAssistantMessageParam(
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]
@@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
adapter.shield_store.get_shield.return_value = None
messages = [
UserMessage(role="user", content="Hello, how are you?"),
OpenAIUserMessageParam(content="Hello, how are you?"),
]
with pytest.raises(ValueError):
@@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
# Running the shield should raise an exception
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
OpenAIUserMessageParam(content="Hello, how are you?"),
OpenAIAssistantMessageParam(
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]

View file

@@ -1,220 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from pydantic import ValidationError
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
CompletionMessage,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
SystemMessage,
UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
convert_message_to_openai_dict_new,
openai_messages_to_messages,
)
async def test_convert_message_to_openai_dict():
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
assert await convert_message_to_openai_dict(message) == {
"role": "user",
"content": [{"type": "text", "text": "Hello, world!"}],
}
# Test convert_message_to_openai_dict with a tool call
async def test_convert_message_to_openai_dict_with_tool_call():
message = CompletionMessage(
content="",
tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
stop_reason=StopReason.end_of_turn,
)
openai_dict = await convert_message_to_openai_dict(message)
assert openai_dict == {
"role": "assistant",
"content": [{"type": "text", "text": ""}],
"tool_calls": [
{"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
],
}
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
message = CompletionMessage(
content="",
tool_calls=[
ToolCall(
call_id="123",
tool_name=BuiltinTool.brave_search,
arguments='{"foo": "bar"}',
)
],
stop_reason=StopReason.end_of_turn,
)
openai_dict = await convert_message_to_openai_dict(message)
assert openai_dict == {
"role": "assistant",
"content": [{"type": "text", "text": ""}],
"tool_calls": [
{"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
],
}
async def test_openai_messages_to_messages_with_content_str():
openai_messages = [
OpenAISystemMessageParam(content="system message"),
OpenAIUserMessageParam(content="user message"),
OpenAIAssistantMessageParam(content="assistant message"),
]
llama_messages = openai_messages_to_messages(openai_messages)
assert len(llama_messages) == 3
assert isinstance(llama_messages[0], SystemMessage)
assert isinstance(llama_messages[1], UserMessage)
assert isinstance(llama_messages[2], CompletionMessage)
assert llama_messages[0].content == "system message"
assert llama_messages[1].content == "user message"
assert llama_messages[2].content == "assistant message"
async def test_openai_messages_to_messages_with_content_list():
openai_messages = [
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
]
llama_messages = openai_messages_to_messages(openai_messages)
assert len(llama_messages) == 3
assert isinstance(llama_messages[0], SystemMessage)
assert isinstance(llama_messages[1], UserMessage)
assert isinstance(llama_messages[2], CompletionMessage)
assert llama_messages[0].content[0].text == "system message"
assert llama_messages[1].content[0].text == "user message"
assert llama_messages[2].content[0].text == "assistant message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_string(message_class, kwargs):
"""Test that messages accept string text content."""
msg = message_class(content="Test message", **kwargs)
assert msg.content == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_list(message_class, kwargs):
"""Test that messages accept list of text content parts."""
content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
msg = message_class(content=content_list, **kwargs)
assert len(msg.content) == 1
assert msg.content[0].text == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_rejects_images(message_class, kwargs):
"""Test that system, assistant, developer, and tool messages reject image content."""
with pytest.raises(ValidationError):
message_class(
content=[
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
],
**kwargs,
)
def test_user_message_accepts_images():
"""Test that user messages accept image content (unlike other message types)."""
# List with images should work
msg = OpenAIUserMessageParam(
content=[
OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
]
)
assert len(msg.content) == 2
assert msg.content[0].text == "Describe this image:"
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
async def test_convert_message_to_openai_dict_new_user_message():
"""Test convert_message_to_openai_dict_new with UserMessage."""
message = UserMessage(content="Hello, world!", role="user")
result = await convert_message_to_openai_dict_new(message)
assert result["role"] == "user"
assert result["content"] == "Hello, world!"
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
message = CompletionMessage(
content="I'll help you find the weather.",
tool_calls=[
ToolCall(
call_id="call_123",
tool_name="get_weather",
arguments='{"city": "Sligo"}',
)
],
stop_reason=StopReason.end_of_turn,
)
result = await convert_message_to_openai_dict_new(message)
# This would have failed with "Cannot instantiate typing.Union" before the fix
assert result["role"] == "assistant"
assert result["content"] == "I'll help you find the weather."
assert "tool_calls" in result
assert result["tool_calls"] is not None
assert len(result["tool_calls"]) == 1
tool_call = result["tool_calls"][0]
assert tool_call.id == "call_123"
assert tool_call.type == "function"
assert tool_call.function.name == "get_weather"
assert tool_call.function.arguments == '{"city": "Sligo"}'

View file

@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.models.llama.datatypes import RawTextItem
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_openai_message_to_raw_message,
)
class TestConvertOpenAIMessageToRawMessage:
"""Test conversion of OpenAI message types to RawMessage format."""
async def test_user_message_conversion(self):
msg = OpenAIUserMessageParam(role="user", content="Hello world")
raw_msg = await convert_openai_message_to_raw_message(msg)
assert raw_msg.role == "user"
assert isinstance(raw_msg.content, RawTextItem)
assert raw_msg.content.text == "Hello world"
async def test_assistant_message_conversion(self):
msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
raw_msg = await convert_openai_message_to_raw_message(msg)
assert raw_msg.role == "assistant"
assert isinstance(raw_msg.content, RawTextItem)
assert raw_msg.content.text == "Hi there!"
assert raw_msg.tool_calls == []
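# Illustrative, self-contained sketch of the kind of conversion these tests
# exercise, using stand-in dataclasses rather than the real llama_stack types.
from dataclasses import dataclass, field

@dataclass
class RawTextItemSketch:
    text: str

@dataclass
class RawMessageSketch:
    role: str
    content: RawTextItemSketch
    tool_calls: list = field(default_factory=list)

async def convert_openai_message_sketch(msg) -> RawMessageSketch:
    # Assumes msg carries .role and plain-string .content, as in the tests above.
    return RawMessageSketch(role=msg.role, content=RawTextItemSketch(text=msg.content))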

View file

@@ -104,12 +104,18 @@ async def test_paginated_response_url_setting():
route_handler = create_dynamic_typed_route(mock_api_method, "get", "/test/route")
# Mock minimal request
# Mock minimal request with proper state object
request = MagicMock()
request.scope = {"user_attributes": {}, "principal": ""}
request.headers = {}
request.body = AsyncMock(return_value=b"")
# Create a simple state object without auto-generating attributes
class MockState:
pass
request.state = MockState()
result = await route_handler(request)
assert isinstance(result, PaginatedResponse)
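# Why a bare class rather than MagicMock for request.state: MagicMock fabricates
# attributes on access, so hasattr() checks against it can never fail, which would
# defeat the point of a state object "without auto-generating attributes".
# Minimal, self-contained illustration:
from unittest.mock import MagicMock

class _EmptyState:
    pass

assert hasattr(MagicMock(), "anything_at_all")        # MagicMock auto-creates attributes
assert not hasattr(_EmptyState(), "anything_at_all")  # a plain object behaves like real state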

uv.lock (generated)
View file

@@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.12"
resolution-markers = [
"(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
@@ -2166,7 +2166,7 @@ test = [
{ name = "milvus-lite", specifier = ">=2.5.0" },
{ name = "psycopg2-binary", specifier = ">=2.9.0" },
{ name = "pymilvus", specifier = ">=2.6.1" },
{ name = "pypdf" },
{ name = "pypdf", specifier = ">=6.1.3" },
{ name = "qdrant-client" },
{ name = "requests" },
{ name = "sqlalchemy" },
@@ -2219,7 +2219,7 @@ unit = [
{ name = "moto", extras = ["s3"], specifier = ">=5.1.10" },
{ name = "ollama" },
{ name = "psycopg2-binary", specifier = ">=2.9.0" },
{ name = "pypdf" },
{ name = "pypdf", specifier = ">=6.1.3" },
{ name = "sqlalchemy" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
{ name = "sqlite-vec" },
@@ -3973,11 +3973,11 @@ wheels = [
[[package]]
name = "pypdf"
version = "5.9.0"
version = "6.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/89/3a/584b97a228950ed85aec97c811c68473d9b8d149e6a8c155668287cf1a28/pypdf-5.9.0.tar.gz", hash = "sha256:30f67a614d558e495e1fbb157ba58c1de91ffc1718f5e0dfeb82a029233890a1", size = 5035118, upload-time = "2025-07-27T14:04:52.364Z" }
sdist = { url = "https://files.pythonhosted.org/packages/4e/2b/8795ec0378384000b0a37a2b5e6d67fa3d84802945aa2c612a78a784d7d4/pypdf-6.2.0.tar.gz", hash = "sha256:46b4d8495d68ae9c818e7964853cd9984e6a04c19fe7112760195395992dce48", size = 5272001, upload-time = "2025-11-09T11:10:41.911Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/48/d9/6cff57c80a6963e7dd183bf09e9f21604a77716644b1e580e97b259f7612/pypdf-5.9.0-py3-none-any.whl", hash = "sha256:be10a4c54202f46d9daceaa8788be07aa8cd5ea8c25c529c50dd509206382c35", size = 313193, upload-time = "2025-07-27T14:04:50.53Z" },
{ url = "https://files.pythonhosted.org/packages/de/ba/743ddcaf1a8fb439342399645921e2cf2c600464cba5531a11f1cc0822b6/pypdf-6.2.0-py3-none-any.whl", hash = "sha256:4c0f3e62677217a777ab79abe22bf1285442d70efabf552f61c7a03b6f5c569f", size = 326592, upload-time = "2025-11-09T11:10:39.941Z" },
]
[[package]]