Merge branch 'main' into feat/add-dana-agent-provider-stub

This commit is contained in:
Zooey Nguyen 2025-11-10 16:36:05 -08:00 committed by GitHub
commit 3f85df3da2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
62 changed files with 3463 additions and 3817 deletions

View file

@ -963,7 +963,7 @@ paths:
Optional filter to control which routes are returned. Can be an API level Optional filter to control which routes are returned. Can be an API level
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
or 'deprecated' to show deprecated routes across all levels. If not specified, or 'deprecated' to show deprecated routes across all levels. If not specified,
returns only non-deprecated v1 routes. returns all non-deprecated routes.
required: false required: false
schema: schema:
type: string type: string
@ -998,39 +998,6 @@ paths:
description: List models using the OpenAI API. description: List models using the OpenAI API.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: A Model.
content:
application/json:
schema:
$ref: '#/components/schemas/Model'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Register model.
description: >-
Register model.
Register a model.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterModelRequest'
required: true
deprecated: false
/v1/models/{model_id}: /v1/models/{model_id}:
get: get:
responses: responses:
@ -1065,36 +1032,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Unregister model.
description: >-
Unregister model.
Unregister a model.
parameters:
- name: model_id
in: path
description: >-
The identifier of the model to unregister.
required: true
schema:
type: string
deprecated: false
/v1/moderations: /v1/moderations:
post: post:
responses: responses:
@ -1725,32 +1662,6 @@ paths:
description: List all scoring functions. description: List all scoring functions.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Register a scoring function.
description: Register a scoring function.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterScoringFunctionRequest'
required: true
deprecated: false
/v1/scoring-functions/{scoring_fn_id}: /v1/scoring-functions/{scoring_fn_id}:
get: get:
responses: responses:
@ -1782,33 +1693,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Unregister a scoring function.
description: Unregister a scoring function.
parameters:
- name: scoring_fn_id
in: path
description: >-
The ID of the scoring function to unregister.
required: true
schema:
type: string
deprecated: false
/v1/scoring/score: /v1/scoring/score:
post: post:
responses: responses:
@ -1897,36 +1781,6 @@ paths:
description: List all shields. description: List all shields.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: A Shield.
content:
application/json:
schema:
$ref: '#/components/schemas/Shield'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Register a shield.
description: Register a shield.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterShieldRequest'
required: true
deprecated: false
/v1/shields/{identifier}: /v1/shields/{identifier}:
get: get:
responses: responses:
@ -1958,33 +1812,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Unregister a shield.
description: Unregister a shield.
parameters:
- name: identifier
in: path
description: >-
The identifier of the shield to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tool-runtime/invoke: /v1/tool-runtime/invoke:
post: post:
responses: responses:
@ -2080,32 +1907,6 @@ paths:
description: List tool groups with optional provider. description: List tool groups with optional provider.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Register a tool group.
description: Register a tool group.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterToolGroupRequest'
required: true
deprecated: false
/v1/toolgroups/{toolgroup_id}: /v1/toolgroups/{toolgroup_id}:
get: get:
responses: responses:
@ -2137,32 +1938,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Unregister a tool group.
description: Unregister a tool group.
parameters:
- name: toolgroup_id
in: path
description: The ID of the tool group to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tools: /v1/tools:
get: get:
responses: responses:
@ -2916,11 +2691,11 @@ paths:
responses: responses:
'200': '200':
description: >- description: >-
A list of InterleavedContent representing the file contents. A VectorStoreFileContentResponse representing the file contents.
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/VectorStoreFileContentsResponse' $ref: '#/components/schemas/VectorStoreFileContentResponse'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -3171,7 +2946,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/RegisterDatasetRequest' $ref: '#/components/schemas/RegisterDatasetRequest'
required: true required: true
deprecated: false deprecated: true
/v1beta/datasets/{dataset_id}: /v1beta/datasets/{dataset_id}:
get: get:
responses: responses:
@ -3228,7 +3003,7 @@ paths:
required: true required: true
schema: schema:
type: string type: string
deprecated: false deprecated: true
/v1alpha/eval/benchmarks: /v1alpha/eval/benchmarks:
get: get:
responses: responses:
@ -3279,7 +3054,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/RegisterBenchmarkRequest' $ref: '#/components/schemas/RegisterBenchmarkRequest'
required: true required: true
deprecated: false deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}: /v1alpha/eval/benchmarks/{benchmark_id}:
get: get:
responses: responses:
@ -3336,7 +3111,7 @@ paths:
required: true required: true
schema: schema:
type: string type: string
deprecated: false deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}/evaluations: /v1alpha/eval/benchmarks/{benchmark_id}/evaluations:
post: post:
responses: responses:
@ -6280,46 +6055,6 @@ components:
required: required:
- data - data
title: OpenAIListModelsResponse title: OpenAIListModelsResponse
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RegisterModelRequest:
type: object
properties:
model_id:
type: string
description: The identifier of the model to register.
provider_model_id:
type: string
description: >-
The identifier of the model in the provider.
provider_id:
type: string
description: The identifier of the provider.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Any additional metadata for this model.
model_type:
$ref: '#/components/schemas/ModelType'
description: The type of model to register.
additionalProperties: false
required:
- model_id
title: RegisterModelRequest
Model: Model:
type: object type: object
properties: properties:
@ -6377,6 +6112,15 @@ components:
title: Model title: Model
description: >- description: >-
A model resource representing an AI model registered in Llama Stack. A model resource representing an AI model registered in Llama Stack.
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RunModerationRequest: RunModerationRequest:
type: object type: object
properties: properties:
@ -6882,6 +6626,11 @@ components:
type: string type: string
description: >- description: >-
(Optional) System message inserted into the model's context (Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
input: input:
type: array type: array
items: items:
@ -7240,6 +6989,11 @@ components:
(Optional) Additional fields to include in the response. (Optional) Additional fields to include in the response.
max_infer_iters: max_infer_iters:
type: integer type: integer
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response.
additionalProperties: false additionalProperties: false
required: required:
- input - input
@ -7321,6 +7075,11 @@ components:
type: string type: string
description: >- description: >-
(Optional) System message inserted into the model's context (Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
additionalProperties: false additionalProperties: false
required: required:
- created_at - created_at
@ -9115,61 +8874,6 @@ components:
required: required:
- data - data
title: ListScoringFunctionsResponse title: ListScoringFunctionsResponse
ParamType:
oneOf:
- $ref: '#/components/schemas/StringType'
- $ref: '#/components/schemas/NumberType'
- $ref: '#/components/schemas/BooleanType'
- $ref: '#/components/schemas/ArrayType'
- $ref: '#/components/schemas/ObjectType'
- $ref: '#/components/schemas/JsonType'
- $ref: '#/components/schemas/UnionType'
- $ref: '#/components/schemas/ChatCompletionInputType'
- $ref: '#/components/schemas/CompletionInputType'
discriminator:
propertyName: type
mapping:
string: '#/components/schemas/StringType'
number: '#/components/schemas/NumberType'
boolean: '#/components/schemas/BooleanType'
array: '#/components/schemas/ArrayType'
object: '#/components/schemas/ObjectType'
json: '#/components/schemas/JsonType'
union: '#/components/schemas/UnionType'
chat_completion_input: '#/components/schemas/ChatCompletionInputType'
completion_input: '#/components/schemas/CompletionInputType'
RegisterScoringFunctionRequest:
type: object
properties:
scoring_fn_id:
type: string
description: >-
The ID of the scoring function to register.
description:
type: string
description: The description of the scoring function.
return_type:
$ref: '#/components/schemas/ParamType'
description: The return type of the scoring function.
provider_scoring_fn_id:
type: string
description: >-
The ID of the provider scoring function to use for the scoring function.
provider_id:
type: string
description: >-
The ID of the provider to use for the scoring function.
params:
$ref: '#/components/schemas/ScoringFnParams'
description: >-
The parameters for the scoring function for benchmark eval, these can
be overridden for app eval.
additionalProperties: false
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequest
ScoreRequest: ScoreRequest:
type: object type: object
properties: properties:
@ -9345,35 +9049,6 @@ components:
required: required:
- data - data
title: ListShieldsResponse title: ListShieldsResponse
RegisterShieldRequest:
type: object
properties:
shield_id:
type: string
description: >-
The identifier of the shield to register.
provider_shield_id:
type: string
description: >-
The identifier of the shield in the provider.
provider_id:
type: string
description: The identifier of the provider.
params:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The parameters of the shield.
additionalProperties: false
required:
- shield_id
title: RegisterShieldRequest
InvokeToolRequest: InvokeToolRequest:
type: object type: object
properties: properties:
@ -9634,37 +9309,6 @@ components:
title: ListToolGroupsResponse title: ListToolGroupsResponse
description: >- description: >-
Response containing a list of tool groups. Response containing a list of tool groups.
RegisterToolGroupRequest:
type: object
properties:
toolgroup_id:
type: string
description: The ID of the tool group to register.
provider_id:
type: string
description: >-
The ID of the provider to use for the tool group.
mcp_endpoint:
$ref: '#/components/schemas/URL'
description: >-
The MCP endpoint to use for the tool group.
args:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
A dictionary of arguments to pass to the tool group.
additionalProperties: false
required:
- toolgroup_id
- provider_id
title: RegisterToolGroupRequest
Chunk: Chunk:
type: object type: object
properties: properties:
@ -10465,41 +10109,35 @@ components:
title: VectorStoreContent title: VectorStoreContent
description: >- description: >-
Content item from a vector store file or search result. Content item from a vector store file or search result.
VectorStoreFileContentsResponse: VectorStoreFileContentResponse:
type: object type: object
properties: properties:
file_id: object:
type: string type: string
description: Unique identifier for the file const: vector_store.file_content.page
filename: default: vector_store.file_content.page
type: string
description: Name of the file
attributes:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >- description: >-
Key-value attributes associated with the file The object type, which is always `vector_store.file_content.page`
content: data:
type: array type: array
items: items:
$ref: '#/components/schemas/VectorStoreContent' $ref: '#/components/schemas/VectorStoreContent'
description: List of content items from the file description: Parsed content of the file
has_more:
type: boolean
description: >-
Indicates if there are more content pages to fetch
next_page:
type: string
description: The token for the next page, if any
additionalProperties: false additionalProperties: false
required: required:
- file_id - object
- filename - data
- attributes - has_more
- content title: VectorStoreFileContentResponse
title: VectorStoreFileContentsResponse
description: >- description: >-
Response from retrieving the contents of a vector store file. Represents the parsed content of a vector store file.
OpenaiSearchVectorStoreRequest: OpenaiSearchVectorStoreRequest:
type: object type: object
properties: properties:
@ -10816,68 +10454,6 @@ components:
- data - data
title: ListDatasetsResponse title: ListDatasetsResponse
description: Response from listing datasets. description: Response from listing datasets.
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
Benchmark: Benchmark:
type: object type: object
properties: properties:
@ -10945,47 +10521,6 @@ components:
required: required:
- data - data
title: ListBenchmarksResponse title: ListBenchmarksResponse
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
BenchmarkConfig: BenchmarkConfig:
type: object type: object
properties: properties:
@ -11847,6 +11382,109 @@ components:
- hyperparam_search_config - hyperparam_search_config
- logger_config - logger_config
title: SupervisedFineTuneRequest title: SupervisedFineTuneRequest
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
responses: responses:
BadRequest400: BadRequest400:
description: The request was invalid or malformed description: The request was invalid or malformed

View file

@ -10,7 +10,7 @@ import TabItem from '@theme/TabItem';
# Kubernetes Deployment Guide # Kubernetes Deployment Guide
Deploy Llama Stack and vLLM servers in a Kubernetes cluster instead of running them locally. This guide covers both local development with Kind and production deployment on AWS EKS. Deploy Llama Stack and vLLM servers in a Kubernetes cluster instead of running them locally. This guide covers deployment using the Kubernetes operator to manage the Llama Stack server with Kind. The vLLM inference server is deployed manually.
## Prerequisites ## Prerequisites
@ -110,115 +110,176 @@ spec:
EOF EOF
``` ```
### Step 3: Configure Llama Stack ### Step 3: Install Kubernetes Operator
Update your run configuration: Install the Llama Stack Kubernetes operator to manage Llama Stack deployments:
```yaml
providers:
inference:
- provider_id: vllm
provider_type: remote::vllm
config:
url: http://vllm-server.default.svc.cluster.local:8000/v1
max_tokens: 4096
api_token: fake
```
Build container image:
```bash ```bash
tmp_dir=$(mktemp -d) && cat >$tmp_dir/Containerfile.llama-stack-run-k8s <<EOF # Install from the latest main branch
FROM distribution-myenv:dev kubectl apply -f https://raw.githubusercontent.com/llamastack/llama-stack-k8s-operator/main/release/operator.yaml
RUN apt-get update && apt-get install -y git
RUN git clone https://github.com/meta-llama/llama-stack.git /app/llama-stack-source # Or install a specific version (e.g., v0.4.0)
ADD ./vllm-llama-stack-run-k8s.yaml /app/config.yaml # kubectl apply -f https://raw.githubusercontent.com/llamastack/llama-stack-k8s-operator/v0.4.0/release/operator.yaml
EOF
podman build -f $tmp_dir/Containerfile.llama-stack-run-k8s -t llama-stack-run-k8s $tmp_dir
``` ```
### Step 4: Deploy Llama Stack Server Verify the operator is running:
```bash
kubectl get pods -n llama-stack-operator-system
```
For more information about the operator, see the [llama-stack-k8s-operator repository](https://github.com/llamastack/llama-stack-k8s-operator).
### Step 4: Deploy Llama Stack Server using Operator
Create a `LlamaStackDistribution` custom resource to deploy the Llama Stack server. The operator will automatically create the necessary Deployment, Service, and other resources.
You can optionally override the default `run.yaml` using `spec.server.userConfig` with a ConfigMap (see [userConfig spec](https://github.com/llamastack/llama-stack-k8s-operator/blob/main/docs/api-overview.md#userconfigspec)).
```yaml ```yaml
cat <<EOF | kubectl apply -f - cat <<EOF | kubectl apply -f -
apiVersion: v1 apiVersion: llamastack.io/v1alpha1
kind: PersistentVolumeClaim kind: LlamaStackDistribution
metadata: metadata:
name: llama-pvc name: llamastack-vllm
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 1Gi
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: llama-stack-server
spec: spec:
replicas: 1 replicas: 1
selector: server:
matchLabels: distribution:
app.kubernetes.io/name: llama-stack name: starter
template: containerSpec:
metadata: port: 8321
labels: env:
app.kubernetes.io/name: llama-stack - name: VLLM_URL
spec: value: "http://vllm-server.default.svc.cluster.local:8000/v1"
containers: - name: VLLM_MAX_TOKENS
- name: llama-stack value: "4096"
image: localhost/llama-stack-run-k8s:latest - name: VLLM_API_TOKEN
imagePullPolicy: IfNotPresent value: "fake"
command: ["llama", "stack", "run", "/app/config.yaml"] # Optional: override run.yaml from a ConfigMap using userConfig
ports: userConfig:
- containerPort: 5000 configMap:
volumeMounts: name: llama-stack-config
- name: llama-storage storage:
mountPath: /root/.llama size: "20Gi"
volumes: mountPath: "/home/lls/.lls"
- name: llama-storage
persistentVolumeClaim:
claimName: llama-pvc
---
apiVersion: v1
kind: Service
metadata:
name: llama-stack-service
spec:
selector:
app.kubernetes.io/name: llama-stack
ports:
- protocol: TCP
port: 5000
targetPort: 5000
type: ClusterIP
EOF EOF
``` ```
**Configuration Options:**
- `replicas`: Number of Llama Stack server instances to run
- `server.distribution.name`: The distribution to use (e.g., `starter` for the starter distribution). See the [list of supported distributions](https://github.com/llamastack/llama-stack-k8s-operator/blob/main/distributions.json) in the operator repository.
- `server.distribution.image`: (Optional) Custom container image for non-supported distributions. Use this field when deploying a distribution that is not in the supported list. If specified, this takes precedence over `name`.
- `server.containerSpec.port`: Port on which the Llama Stack server listens (default: 8321)
- `server.containerSpec.env`: Environment variables to configure providers:
- `server.userConfig`: (Optional) Override the default `run.yaml` using a ConfigMap. See [userConfig spec](https://github.com/llamastack/llama-stack-k8s-operator/blob/main/docs/api-overview.md#userconfigspec).
- `server.storage.size`: Size of the persistent volume for model and data storage
- `server.storage.mountPath`: Where to mount the storage in the container
**Note:** For a complete list of supported distributions, see [distributions.json](https://github.com/llamastack/llama-stack-k8s-operator/blob/main/distributions.json) in the operator repository. To use a custom or non-supported distribution, set the `server.distribution.image` field with your container image instead of `server.distribution.name`.
The operator automatically creates:
- A Deployment for the Llama Stack server
- A Service to access the server
- A PersistentVolumeClaim for storage
- All necessary RBAC resources
Check the status of your deployment:
```bash
kubectl get llamastackdistribution
kubectl describe llamastackdistribution llamastack-vllm
```
### Step 5: Test Deployment ### Step 5: Test Deployment
Wait for the Llama Stack server pod to be ready:
```bash ```bash
# Port forward and test # Check the status of the LlamaStackDistribution
kubectl port-forward service/llama-stack-service 5000:5000 kubectl get llamastackdistribution llamastack-vllm
llama-stack-client --endpoint http://localhost:5000 inference chat-completion --message "hello, what model are you?"
# Check the pods created by the operator
kubectl get pods -l app.kubernetes.io/name=llama-stack
# Wait for the pod to be ready
kubectl wait --for=condition=ready pod -l app.kubernetes.io/name=llama-stack --timeout=300s
```
Get the service name created by the operator (it typically follows the pattern `<llamastackdistribution-name>-service`):
```bash
# List services to find the service name
kubectl get services | grep llamastack
# Port forward and test (replace SERVICE_NAME with the actual service name)
kubectl port-forward service/llamastack-vllm-service 8321:8321
```
In another terminal, test the deployment:
```bash
llama-stack-client --endpoint http://localhost:8321 inference chat-completion --message "hello, what model are you?"
``` ```
## Troubleshooting ## Troubleshooting
**Check pod status:** ### vLLM Server Issues
**Check vLLM pod status:**
```bash ```bash
kubectl get pods -l app.kubernetes.io/name=vllm kubectl get pods -l app.kubernetes.io/name=vllm
kubectl logs -l app.kubernetes.io/name=vllm kubectl logs -l app.kubernetes.io/name=vllm
``` ```
**Test service connectivity:** **Test vLLM service connectivity:**
```bash ```bash
kubectl run -it --rm debug --image=curlimages/curl --restart=Never -- curl http://vllm-server:8000/v1/models kubectl run -it --rm debug --image=curlimages/curl --restart=Never -- curl http://vllm-server:8000/v1/models
``` ```
### Llama Stack Server Issues
**Check LlamaStackDistribution status:**
```bash
# Get detailed status
kubectl describe llamastackdistribution llamastack-vllm
# Check for events
kubectl get events --sort-by='.lastTimestamp' | grep llamastack-vllm
```
**Check operator-managed pods:**
```bash
# List all pods managed by the operator
kubectl get pods -l app.kubernetes.io/name=llama-stack
# Check pod logs (replace POD_NAME with actual pod name)
kubectl logs -l app.kubernetes.io/name=llama-stack
```
**Check operator status:**
```bash
# Verify the operator is running
kubectl get pods -n llama-stack-operator-system
# Check operator logs if issues persist
kubectl logs -n llama-stack-operator-system -l control-plane=controller-manager
```
**Verify service connectivity:**
```bash
# Get the service endpoint
kubectl get svc llamastack-vllm-service
# Test connectivity from within the cluster
kubectl run -it --rm debug --image=curlimages/curl --restart=Never -- curl http://llamastack-vllm-service:8321/health
```
## Related Resources ## Related Resources
- **[Deployment Overview](/docs/deploying/)** - Overview of deployment options - **[Deployment Overview](/docs/deploying/)** - Overview of deployment options
- **[Distributions](/docs/distributions)** - Understanding Llama Stack distributions - **[Distributions](/docs/distributions)** - Understanding Llama Stack distributions
- **[Configuration](/docs/distributions/configuration)** - Detailed configuration options - **[Configuration](/docs/distributions/configuration)** - Detailed configuration options
- **[LlamaStack Operator](https://github.com/llamastack/llama-stack-k8s-operator)** - Overview of llama-stack kubernetes operator
- **[LlamaStackDistribution](https://github.com/llamastack/llama-stack-k8s-operator/blob/main/docs/api-overview.md)** - API Spec of the llama-stack operator Custom Resource.

View file

@ -0,0 +1,143 @@
---
orphan: true
---
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# OCI Distribution
The `llamastack/distribution-oci` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| files | `inline::localfs` |
| inference | `remote::oci` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
### Environment Variables
The following environment variables can be configured:
- `OCI_AUTH_TYPE`: OCI authentication type (instance_principal or config_file) (default: `instance_principal`)
- `OCI_REGION`: OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1) (default: ``)
- `OCI_COMPARTMENT_OCID`: OCI compartment ID for the Generative AI service (default: ``)
- `OCI_CONFIG_FILE_PATH`: OCI config file path (required if OCI_AUTH_TYPE is config_file) (default: `~/.oci/config`)
- `OCI_CLI_PROFILE`: OCI CLI profile name to use from config file (default: `DEFAULT`)
## Prerequisites
### Oracle Cloud Infrastructure Setup
Before using the OCI Generative AI distribution, ensure you have:
1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/)
2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy
3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models
4. **Authentication**: Configure authentication using either:
- **Instance Principal** (recommended for cloud-hosted deployments)
- **API Key** (for on-premises or development environments)
### Authentication Methods
#### Instance Principal Authentication (Recommended)
Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments.
Requirements:
- Instance must be running in an Oracle Cloud Infrastructure compartment
- Instance must have appropriate IAM policies to access Generative AI services
#### API Key Authentication
For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file.
### Required IAM Policies
Ensure your OCI user or instance has the following policy statements:
```
Allow group <group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
Allow group <group_name> to manage generative-ai-inference-endpoints in compartment <compartment_name>
```
## Supported Services
### Inference: OCI Generative AI
Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports:
- **Chat Completions**: Conversational AI with context awareness
- **Text Generation**: Complete prompts and generate text content
#### Available Models
Common OCI Generative AI models include access to Meta, Cohere, OpenAI, Grok, and more models.
### Safety: Llama Guard
For content safety and moderation, this distribution uses Meta's LlamaGuard model through the OCI Generative AI service to provide:
- Content filtering and moderation
- Policy compliance checking
- Harmful content detection
### Vector Storage: Multiple Options
The distribution supports several vector storage providers:
- **FAISS**: Local in-memory vector search
- **ChromaDB**: Distributed vector database
- **PGVector**: PostgreSQL with vector extensions
### Additional Services
- **Dataset I/O**: Local filesystem and Hugging Face integration
- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities
- **Evaluation**: Meta reference evaluation framework
## Running Llama Stack with OCI
You can run the OCI distribution via Docker or local virtual environment.
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
```bash
OCI_AUTH=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci
```
### Configuration Examples
#### Using Instance Principal (Recommended for Production)
```bash
export OCI_AUTH_TYPE=instance_principal
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..<your-compartment-id>
```
#### Using API Key Authentication (Development)
```bash
export OCI_AUTH_TYPE=config_file
export OCI_CONFIG_FILE_PATH=~/.oci/config
export OCI_CLI_PROFILE=DEFAULT
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
```
## Regional Endpoints
OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit:
https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions
## Troubleshooting
### Common Issues
1. **Authentication Errors**: Verify your OCI credentials and IAM policies
2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region
3. **Permission Denied**: Check compartment permissions and Generative AI service access
4. **Region Unavailable**: Verify the specified region supports Generative AI services
### Getting Help
For additional support:
- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)
- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues)

View file

@ -0,0 +1,41 @@
---
description: |
Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models.
Provider documentation
https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
sidebar_label: Remote - Oci
title: remote::oci
---
# remote::oci
## Description
Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models.
Provider documentation
https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
| `oci_auth_type` | `<class 'str'>` | No | instance_principal | OCI authentication type (must be one of: instance_principal, config_file) |
| `oci_region` | `<class 'str'>` | No | us-ashburn-1 | OCI region (e.g., us-ashburn-1) |
| `oci_compartment_id` | `<class 'str'>` | No | | OCI compartment ID for the Generative AI service |
| `oci_config_file_path` | `<class 'str'>` | No | ~/.oci/config | OCI config file path (required if oci_auth_type is config_file) |
| `oci_config_profile` | `<class 'str'>` | No | DEFAULT | OCI config profile (required if oci_auth_type is config_file) |
## Sample Configuration
```yaml
oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
oci_region: ${env.OCI_REGION:=us-ashburn-1}
oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
```

File diff suppressed because it is too large Load diff

View file

@ -162,7 +162,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/RegisterDatasetRequest' $ref: '#/components/schemas/RegisterDatasetRequest'
required: true required: true
deprecated: false deprecated: true
/v1beta/datasets/{dataset_id}: /v1beta/datasets/{dataset_id}:
get: get:
responses: responses:
@ -219,7 +219,7 @@ paths:
required: true required: true
schema: schema:
type: string type: string
deprecated: false deprecated: true
/v1alpha/eval/benchmarks: /v1alpha/eval/benchmarks:
get: get:
responses: responses:
@ -270,7 +270,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/RegisterBenchmarkRequest' $ref: '#/components/schemas/RegisterBenchmarkRequest'
required: true required: true
deprecated: false deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}: /v1alpha/eval/benchmarks/{benchmark_id}:
get: get:
responses: responses:
@ -327,7 +327,7 @@ paths:
required: true required: true
schema: schema:
type: string type: string
deprecated: false deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}/evaluations: /v1alpha/eval/benchmarks/{benchmark_id}/evaluations:
post: post:
responses: responses:
@ -936,68 +936,6 @@ components:
- data - data
title: ListDatasetsResponse title: ListDatasetsResponse
description: Response from listing datasets. description: Response from listing datasets.
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
Benchmark: Benchmark:
type: object type: object
properties: properties:
@ -1065,47 +1003,6 @@ components:
required: required:
- data - data
title: ListBenchmarksResponse title: ListBenchmarksResponse
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
AggregationFunctionType: AggregationFunctionType:
type: string type: string
enum: enum:
@ -2254,6 +2151,109 @@ components:
- hyperparam_search_config - hyperparam_search_config
- logger_config - logger_config
title: SupervisedFineTuneRequest title: SupervisedFineTuneRequest
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
responses: responses:
BadRequest400: BadRequest400:
description: The request was invalid or malformed description: The request was invalid or malformed

View file

@ -960,7 +960,7 @@ paths:
Optional filter to control which routes are returned. Can be an API level Optional filter to control which routes are returned. Can be an API level
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
or 'deprecated' to show deprecated routes across all levels. If not specified, or 'deprecated' to show deprecated routes across all levels. If not specified,
returns only non-deprecated v1 routes. returns all non-deprecated routes.
required: false required: false
schema: schema:
type: string type: string
@ -995,39 +995,6 @@ paths:
description: List models using the OpenAI API. description: List models using the OpenAI API.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: A Model.
content:
application/json:
schema:
$ref: '#/components/schemas/Model'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Register model.
description: >-
Register model.
Register a model.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterModelRequest'
required: true
deprecated: false
/v1/models/{model_id}: /v1/models/{model_id}:
get: get:
responses: responses:
@ -1062,36 +1029,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Unregister model.
description: >-
Unregister model.
Unregister a model.
parameters:
- name: model_id
in: path
description: >-
The identifier of the model to unregister.
required: true
schema:
type: string
deprecated: false
/v1/moderations: /v1/moderations:
post: post:
responses: responses:
@ -1722,32 +1659,6 @@ paths:
description: List all scoring functions. description: List all scoring functions.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Register a scoring function.
description: Register a scoring function.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterScoringFunctionRequest'
required: true
deprecated: false
/v1/scoring-functions/{scoring_fn_id}: /v1/scoring-functions/{scoring_fn_id}:
get: get:
responses: responses:
@ -1779,33 +1690,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Unregister a scoring function.
description: Unregister a scoring function.
parameters:
- name: scoring_fn_id
in: path
description: >-
The ID of the scoring function to unregister.
required: true
schema:
type: string
deprecated: false
/v1/scoring/score: /v1/scoring/score:
post: post:
responses: responses:
@ -1894,36 +1778,6 @@ paths:
description: List all shields. description: List all shields.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: A Shield.
content:
application/json:
schema:
$ref: '#/components/schemas/Shield'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Register a shield.
description: Register a shield.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterShieldRequest'
required: true
deprecated: false
/v1/shields/{identifier}: /v1/shields/{identifier}:
get: get:
responses: responses:
@ -1955,33 +1809,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Unregister a shield.
description: Unregister a shield.
parameters:
- name: identifier
in: path
description: >-
The identifier of the shield to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tool-runtime/invoke: /v1/tool-runtime/invoke:
post: post:
responses: responses:
@ -2077,32 +1904,6 @@ paths:
description: List tool groups with optional provider. description: List tool groups with optional provider.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Register a tool group.
description: Register a tool group.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterToolGroupRequest'
required: true
deprecated: false
/v1/toolgroups/{toolgroup_id}: /v1/toolgroups/{toolgroup_id}:
get: get:
responses: responses:
@ -2134,32 +1935,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Unregister a tool group.
description: Unregister a tool group.
parameters:
- name: toolgroup_id
in: path
description: The ID of the tool group to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tools: /v1/tools:
get: get:
responses: responses:
@ -2913,11 +2688,11 @@ paths:
responses: responses:
'200': '200':
description: >- description: >-
A list of InterleavedContent representing the file contents. A VectorStoreFileContentResponse representing the file contents.
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/VectorStoreFileContentsResponse' $ref: '#/components/schemas/VectorStoreFileContentResponse'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -5564,46 +5339,6 @@ components:
required: required:
- data - data
title: OpenAIListModelsResponse title: OpenAIListModelsResponse
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RegisterModelRequest:
type: object
properties:
model_id:
type: string
description: The identifier of the model to register.
provider_model_id:
type: string
description: >-
The identifier of the model in the provider.
provider_id:
type: string
description: The identifier of the provider.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Any additional metadata for this model.
model_type:
$ref: '#/components/schemas/ModelType'
description: The type of model to register.
additionalProperties: false
required:
- model_id
title: RegisterModelRequest
Model: Model:
type: object type: object
properties: properties:
@ -5661,6 +5396,15 @@ components:
title: Model title: Model
description: >- description: >-
A model resource representing an AI model registered in Llama Stack. A model resource representing an AI model registered in Llama Stack.
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RunModerationRequest: RunModerationRequest:
type: object type: object
properties: properties:
@ -6166,6 +5910,11 @@ components:
type: string type: string
description: >- description: >-
(Optional) System message inserted into the model's context (Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
input: input:
type: array type: array
items: items:
@ -6524,6 +6273,11 @@ components:
(Optional) Additional fields to include in the response. (Optional) Additional fields to include in the response.
max_infer_iters: max_infer_iters:
type: integer type: integer
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response.
additionalProperties: false additionalProperties: false
required: required:
- input - input
@ -6605,6 +6359,11 @@ components:
type: string type: string
description: >- description: >-
(Optional) System message inserted into the model's context (Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
additionalProperties: false additionalProperties: false
required: required:
- created_at - created_at
@ -8399,61 +8158,6 @@ components:
required: required:
- data - data
title: ListScoringFunctionsResponse title: ListScoringFunctionsResponse
ParamType:
oneOf:
- $ref: '#/components/schemas/StringType'
- $ref: '#/components/schemas/NumberType'
- $ref: '#/components/schemas/BooleanType'
- $ref: '#/components/schemas/ArrayType'
- $ref: '#/components/schemas/ObjectType'
- $ref: '#/components/schemas/JsonType'
- $ref: '#/components/schemas/UnionType'
- $ref: '#/components/schemas/ChatCompletionInputType'
- $ref: '#/components/schemas/CompletionInputType'
discriminator:
propertyName: type
mapping:
string: '#/components/schemas/StringType'
number: '#/components/schemas/NumberType'
boolean: '#/components/schemas/BooleanType'
array: '#/components/schemas/ArrayType'
object: '#/components/schemas/ObjectType'
json: '#/components/schemas/JsonType'
union: '#/components/schemas/UnionType'
chat_completion_input: '#/components/schemas/ChatCompletionInputType'
completion_input: '#/components/schemas/CompletionInputType'
RegisterScoringFunctionRequest:
type: object
properties:
scoring_fn_id:
type: string
description: >-
The ID of the scoring function to register.
description:
type: string
description: The description of the scoring function.
return_type:
$ref: '#/components/schemas/ParamType'
description: The return type of the scoring function.
provider_scoring_fn_id:
type: string
description: >-
The ID of the provider scoring function to use for the scoring function.
provider_id:
type: string
description: >-
The ID of the provider to use for the scoring function.
params:
$ref: '#/components/schemas/ScoringFnParams'
description: >-
The parameters for the scoring function for benchmark eval, these can
be overridden for app eval.
additionalProperties: false
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequest
ScoreRequest: ScoreRequest:
type: object type: object
properties: properties:
@ -8629,35 +8333,6 @@ components:
required: required:
- data - data
title: ListShieldsResponse title: ListShieldsResponse
RegisterShieldRequest:
type: object
properties:
shield_id:
type: string
description: >-
The identifier of the shield to register.
provider_shield_id:
type: string
description: >-
The identifier of the shield in the provider.
provider_id:
type: string
description: The identifier of the provider.
params:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The parameters of the shield.
additionalProperties: false
required:
- shield_id
title: RegisterShieldRequest
InvokeToolRequest: InvokeToolRequest:
type: object type: object
properties: properties:
@ -8918,37 +8593,6 @@ components:
title: ListToolGroupsResponse title: ListToolGroupsResponse
description: >- description: >-
Response containing a list of tool groups. Response containing a list of tool groups.
RegisterToolGroupRequest:
type: object
properties:
toolgroup_id:
type: string
description: The ID of the tool group to register.
provider_id:
type: string
description: >-
The ID of the provider to use for the tool group.
mcp_endpoint:
$ref: '#/components/schemas/URL'
description: >-
The MCP endpoint to use for the tool group.
args:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
A dictionary of arguments to pass to the tool group.
additionalProperties: false
required:
- toolgroup_id
- provider_id
title: RegisterToolGroupRequest
Chunk: Chunk:
type: object type: object
properties: properties:
@ -9749,41 +9393,35 @@ components:
title: VectorStoreContent title: VectorStoreContent
description: >- description: >-
Content item from a vector store file or search result. Content item from a vector store file or search result.
VectorStoreFileContentsResponse: VectorStoreFileContentResponse:
type: object type: object
properties: properties:
file_id: object:
type: string type: string
description: Unique identifier for the file const: vector_store.file_content.page
filename: default: vector_store.file_content.page
type: string
description: Name of the file
attributes:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >- description: >-
Key-value attributes associated with the file The object type, which is always `vector_store.file_content.page`
content: data:
type: array type: array
items: items:
$ref: '#/components/schemas/VectorStoreContent' $ref: '#/components/schemas/VectorStoreContent'
description: List of content items from the file description: Parsed content of the file
has_more:
type: boolean
description: >-
Indicates if there are more content pages to fetch
next_page:
type: string
description: The token for the next page, if any
additionalProperties: false additionalProperties: false
required: required:
- file_id - object
- filename - data
- attributes - has_more
- content title: VectorStoreFileContentResponse
title: VectorStoreFileContentsResponse
description: >- description: >-
Response from retrieving the contents of a vector store file. Represents the parsed content of a vector store file.
OpenaiSearchVectorStoreRequest: OpenaiSearchVectorStoreRequest:
type: object type: object
properties: properties:

View file

@ -963,7 +963,7 @@ paths:
Optional filter to control which routes are returned. Can be an API level Optional filter to control which routes are returned. Can be an API level
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
or 'deprecated' to show deprecated routes across all levels. If not specified, or 'deprecated' to show deprecated routes across all levels. If not specified,
returns only non-deprecated v1 routes. returns all non-deprecated routes.
required: false required: false
schema: schema:
type: string type: string
@ -998,39 +998,6 @@ paths:
description: List models using the OpenAI API. description: List models using the OpenAI API.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: A Model.
content:
application/json:
schema:
$ref: '#/components/schemas/Model'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Register model.
description: >-
Register model.
Register a model.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterModelRequest'
required: true
deprecated: false
/v1/models/{model_id}: /v1/models/{model_id}:
get: get:
responses: responses:
@ -1065,36 +1032,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Models
summary: Unregister model.
description: >-
Unregister model.
Unregister a model.
parameters:
- name: model_id
in: path
description: >-
The identifier of the model to unregister.
required: true
schema:
type: string
deprecated: false
/v1/moderations: /v1/moderations:
post: post:
responses: responses:
@ -1725,32 +1662,6 @@ paths:
description: List all scoring functions. description: List all scoring functions.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Register a scoring function.
description: Register a scoring function.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterScoringFunctionRequest'
required: true
deprecated: false
/v1/scoring-functions/{scoring_fn_id}: /v1/scoring-functions/{scoring_fn_id}:
get: get:
responses: responses:
@ -1782,33 +1693,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ScoringFunctions
summary: Unregister a scoring function.
description: Unregister a scoring function.
parameters:
- name: scoring_fn_id
in: path
description: >-
The ID of the scoring function to unregister.
required: true
schema:
type: string
deprecated: false
/v1/scoring/score: /v1/scoring/score:
post: post:
responses: responses:
@ -1897,36 +1781,6 @@ paths:
description: List all shields. description: List all shields.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: A Shield.
content:
application/json:
schema:
$ref: '#/components/schemas/Shield'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Register a shield.
description: Register a shield.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterShieldRequest'
required: true
deprecated: false
/v1/shields/{identifier}: /v1/shields/{identifier}:
get: get:
responses: responses:
@ -1958,33 +1812,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
summary: Unregister a shield.
description: Unregister a shield.
parameters:
- name: identifier
in: path
description: >-
The identifier of the shield to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tool-runtime/invoke: /v1/tool-runtime/invoke:
post: post:
responses: responses:
@ -2080,32 +1907,6 @@ paths:
description: List tool groups with optional provider. description: List tool groups with optional provider.
parameters: [] parameters: []
deprecated: false deprecated: false
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Register a tool group.
description: Register a tool group.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterToolGroupRequest'
required: true
deprecated: false
/v1/toolgroups/{toolgroup_id}: /v1/toolgroups/{toolgroup_id}:
get: get:
responses: responses:
@ -2137,32 +1938,6 @@ paths:
schema: schema:
type: string type: string
deprecated: false deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolGroups
summary: Unregister a tool group.
description: Unregister a tool group.
parameters:
- name: toolgroup_id
in: path
description: The ID of the tool group to unregister.
required: true
schema:
type: string
deprecated: false
/v1/tools: /v1/tools:
get: get:
responses: responses:
@ -2916,11 +2691,11 @@ paths:
responses: responses:
'200': '200':
description: >- description: >-
A list of InterleavedContent representing the file contents. A VectorStoreFileContentResponse representing the file contents.
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/VectorStoreFileContentsResponse' $ref: '#/components/schemas/VectorStoreFileContentResponse'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -3171,7 +2946,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/RegisterDatasetRequest' $ref: '#/components/schemas/RegisterDatasetRequest'
required: true required: true
deprecated: false deprecated: true
/v1beta/datasets/{dataset_id}: /v1beta/datasets/{dataset_id}:
get: get:
responses: responses:
@ -3228,7 +3003,7 @@ paths:
required: true required: true
schema: schema:
type: string type: string
deprecated: false deprecated: true
/v1alpha/eval/benchmarks: /v1alpha/eval/benchmarks:
get: get:
responses: responses:
@ -3279,7 +3054,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/RegisterBenchmarkRequest' $ref: '#/components/schemas/RegisterBenchmarkRequest'
required: true required: true
deprecated: false deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}: /v1alpha/eval/benchmarks/{benchmark_id}:
get: get:
responses: responses:
@ -3336,7 +3111,7 @@ paths:
required: true required: true
schema: schema:
type: string type: string
deprecated: false deprecated: true
/v1alpha/eval/benchmarks/{benchmark_id}/evaluations: /v1alpha/eval/benchmarks/{benchmark_id}/evaluations:
post: post:
responses: responses:
@ -6280,46 +6055,6 @@ components:
required: required:
- data - data
title: OpenAIListModelsResponse title: OpenAIListModelsResponse
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RegisterModelRequest:
type: object
properties:
model_id:
type: string
description: The identifier of the model to register.
provider_model_id:
type: string
description: >-
The identifier of the model in the provider.
provider_id:
type: string
description: The identifier of the provider.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Any additional metadata for this model.
model_type:
$ref: '#/components/schemas/ModelType'
description: The type of model to register.
additionalProperties: false
required:
- model_id
title: RegisterModelRequest
Model: Model:
type: object type: object
properties: properties:
@ -6377,6 +6112,15 @@ components:
title: Model title: Model
description: >- description: >-
A model resource representing an AI model registered in Llama Stack. A model resource representing an AI model registered in Llama Stack.
ModelType:
type: string
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
RunModerationRequest: RunModerationRequest:
type: object type: object
properties: properties:
@ -6882,6 +6626,11 @@ components:
type: string type: string
description: >- description: >-
(Optional) System message inserted into the model's context (Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
input: input:
type: array type: array
items: items:
@ -7240,6 +6989,11 @@ components:
(Optional) Additional fields to include in the response. (Optional) Additional fields to include in the response.
max_infer_iters: max_infer_iters:
type: integer type: integer
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response.
additionalProperties: false additionalProperties: false
required: required:
- input - input
@ -7321,6 +7075,11 @@ components:
type: string type: string
description: >- description: >-
(Optional) System message inserted into the model's context (Optional) System message inserted into the model's context
max_tool_calls:
type: integer
description: >-
(Optional) Max number of total calls to built-in tools that can be processed
in a response
additionalProperties: false additionalProperties: false
required: required:
- created_at - created_at
@ -9115,61 +8874,6 @@ components:
required: required:
- data - data
title: ListScoringFunctionsResponse title: ListScoringFunctionsResponse
ParamType:
oneOf:
- $ref: '#/components/schemas/StringType'
- $ref: '#/components/schemas/NumberType'
- $ref: '#/components/schemas/BooleanType'
- $ref: '#/components/schemas/ArrayType'
- $ref: '#/components/schemas/ObjectType'
- $ref: '#/components/schemas/JsonType'
- $ref: '#/components/schemas/UnionType'
- $ref: '#/components/schemas/ChatCompletionInputType'
- $ref: '#/components/schemas/CompletionInputType'
discriminator:
propertyName: type
mapping:
string: '#/components/schemas/StringType'
number: '#/components/schemas/NumberType'
boolean: '#/components/schemas/BooleanType'
array: '#/components/schemas/ArrayType'
object: '#/components/schemas/ObjectType'
json: '#/components/schemas/JsonType'
union: '#/components/schemas/UnionType'
chat_completion_input: '#/components/schemas/ChatCompletionInputType'
completion_input: '#/components/schemas/CompletionInputType'
RegisterScoringFunctionRequest:
type: object
properties:
scoring_fn_id:
type: string
description: >-
The ID of the scoring function to register.
description:
type: string
description: The description of the scoring function.
return_type:
$ref: '#/components/schemas/ParamType'
description: The return type of the scoring function.
provider_scoring_fn_id:
type: string
description: >-
The ID of the provider scoring function to use for the scoring function.
provider_id:
type: string
description: >-
The ID of the provider to use for the scoring function.
params:
$ref: '#/components/schemas/ScoringFnParams'
description: >-
The parameters for the scoring function for benchmark eval, these can
be overridden for app eval.
additionalProperties: false
required:
- scoring_fn_id
- description
- return_type
title: RegisterScoringFunctionRequest
ScoreRequest: ScoreRequest:
type: object type: object
properties: properties:
@ -9345,35 +9049,6 @@ components:
required: required:
- data - data
title: ListShieldsResponse title: ListShieldsResponse
RegisterShieldRequest:
type: object
properties:
shield_id:
type: string
description: >-
The identifier of the shield to register.
provider_shield_id:
type: string
description: >-
The identifier of the shield in the provider.
provider_id:
type: string
description: The identifier of the provider.
params:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The parameters of the shield.
additionalProperties: false
required:
- shield_id
title: RegisterShieldRequest
InvokeToolRequest: InvokeToolRequest:
type: object type: object
properties: properties:
@ -9634,37 +9309,6 @@ components:
title: ListToolGroupsResponse title: ListToolGroupsResponse
description: >- description: >-
Response containing a list of tool groups. Response containing a list of tool groups.
RegisterToolGroupRequest:
type: object
properties:
toolgroup_id:
type: string
description: The ID of the tool group to register.
provider_id:
type: string
description: >-
The ID of the provider to use for the tool group.
mcp_endpoint:
$ref: '#/components/schemas/URL'
description: >-
The MCP endpoint to use for the tool group.
args:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
A dictionary of arguments to pass to the tool group.
additionalProperties: false
required:
- toolgroup_id
- provider_id
title: RegisterToolGroupRequest
Chunk: Chunk:
type: object type: object
properties: properties:
@ -10465,41 +10109,35 @@ components:
title: VectorStoreContent title: VectorStoreContent
description: >- description: >-
Content item from a vector store file or search result. Content item from a vector store file or search result.
VectorStoreFileContentsResponse: VectorStoreFileContentResponse:
type: object type: object
properties: properties:
file_id: object:
type: string type: string
description: Unique identifier for the file const: vector_store.file_content.page
filename: default: vector_store.file_content.page
type: string
description: Name of the file
attributes:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >- description: >-
Key-value attributes associated with the file The object type, which is always `vector_store.file_content.page`
content: data:
type: array type: array
items: items:
$ref: '#/components/schemas/VectorStoreContent' $ref: '#/components/schemas/VectorStoreContent'
description: List of content items from the file description: Parsed content of the file
has_more:
type: boolean
description: >-
Indicates if there are more content pages to fetch
next_page:
type: string
description: The token for the next page, if any
additionalProperties: false additionalProperties: false
required: required:
- file_id - object
- filename - data
- attributes - has_more
- content title: VectorStoreFileContentResponse
title: VectorStoreFileContentsResponse
description: >- description: >-
Response from retrieving the contents of a vector store file. Represents the parsed content of a vector store file.
OpenaiSearchVectorStoreRequest: OpenaiSearchVectorStoreRequest:
type: object type: object
properties: properties:
@ -10816,68 +10454,6 @@ components:
- data - data
title: ListDatasetsResponse title: ListDatasetsResponse
description: Response from listing datasets. description: Response from listing datasets.
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
Benchmark: Benchmark:
type: object type: object
properties: properties:
@ -10945,47 +10521,6 @@ components:
required: required:
- data - data
title: ListBenchmarksResponse title: ListBenchmarksResponse
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
BenchmarkConfig: BenchmarkConfig:
type: object type: object
properties: properties:
@ -11847,6 +11382,109 @@ components:
- hyperparam_search_config - hyperparam_search_config
- logger_config - logger_config
title: SupervisedFineTuneRequest title: SupervisedFineTuneRequest
DataSource:
oneOf:
- $ref: '#/components/schemas/URIDataSource'
- $ref: '#/components/schemas/RowsDataSource'
discriminator:
propertyName: type
mapping:
uri: '#/components/schemas/URIDataSource'
rows: '#/components/schemas/RowsDataSource'
RegisterDatasetRequest:
type: object
properties:
purpose:
type: string
enum:
- post-training/messages
- eval/question-answer
- eval/messages-answer
description: >-
The purpose of the dataset. One of: - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. {
"messages": [ {"role": "user", "content": "Hello, world!"}, {"role": "assistant",
"content": "Hello, world!"}, ] } - "eval/question-answer": The dataset
contains a question column and an answer column for evaluation. { "question":
"What is the capital of France?", "answer": "Paris" } - "eval/messages-answer":
The dataset contains a messages column with list of messages and an answer
column for evaluation. { "messages": [ {"role": "user", "content": "Hello,
my name is John Doe."}, {"role": "assistant", "content": "Hello, John
Doe. How can I help you today?"}, {"role": "user", "content": "What's
my name?"}, ], "answer": "John Doe" }
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Ensure that the data source schema is
compatible with the purpose of the dataset. Examples: - { "type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}"
} - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"
} - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content":
"Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ]
} ] }
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
The metadata for the dataset. - E.g. {"description": "My dataset"}.
dataset_id:
type: string
description: >-
The ID of the dataset. If not provided, an ID will be generated.
additionalProperties: false
required:
- purpose
- source
title: RegisterDatasetRequest
RegisterBenchmarkRequest:
type: object
properties:
benchmark_id:
type: string
description: The ID of the benchmark to register.
dataset_id:
type: string
description: >-
The ID of the dataset to use for the benchmark.
scoring_functions:
type: array
items:
type: string
description: >-
The scoring functions to use for the benchmark.
provider_benchmark_id:
type: string
description: >-
The ID of the provider benchmark to use for the benchmark.
provider_id:
type: string
description: >-
The ID of the provider to use for the benchmark.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: The metadata to use for the benchmark.
additionalProperties: false
required:
- benchmark_id
- dataset_id
- scoring_functions
title: RegisterBenchmarkRequest
responses: responses:
BadRequest400: BadRequest400:
description: The request was invalid or malformed description: The request was invalid or malformed

View file

@ -298,6 +298,7 @@ exclude = [
"^src/llama_stack/providers/remote/agents/sample/", "^src/llama_stack/providers/remote/agents/sample/",
"^src/llama_stack/providers/remote/datasetio/huggingface/", "^src/llama_stack/providers/remote/datasetio/huggingface/",
"^src/llama_stack/providers/remote/datasetio/nvidia/", "^src/llama_stack/providers/remote/datasetio/nvidia/",
"^src/llama_stack/providers/remote/inference/oci/",
"^src/llama_stack/providers/remote/inference/bedrock/", "^src/llama_stack/providers/remote/inference/bedrock/",
"^src/llama_stack/providers/remote/inference/nvidia/", "^src/llama_stack/providers/remote/inference/nvidia/",
"^src/llama_stack/providers/remote/inference/passthrough/", "^src/llama_stack/providers/remote/inference/passthrough/",

View file

@ -87,6 +87,7 @@ class Agents(Protocol):
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation." "List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
), ),
] = None, ] = None,
max_tool_calls: int | None = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a model response. """Create a model response.
@ -97,6 +98,7 @@ class Agents(Protocol):
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation. :param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response. :param include: (Optional) Additional fields to include in the response.
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications. :param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response.
:returns: An OpenAIResponseObject. :returns: An OpenAIResponseObject.
""" """
... ...

View file

@ -594,6 +594,7 @@ class OpenAIResponseObject(BaseModel):
:param truncation: (Optional) Truncation strategy applied to the response :param truncation: (Optional) Truncation strategy applied to the response
:param usage: (Optional) Token usage information for the response :param usage: (Optional) Token usage information for the response
:param instructions: (Optional) System message inserted into the model's context :param instructions: (Optional) System message inserted into the model's context
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
""" """
created_at: int created_at: int
@ -615,6 +616,7 @@ class OpenAIResponseObject(BaseModel):
truncation: str | None = None truncation: str | None = None
usage: OpenAIResponseUsage | None = None usage: OpenAIResponseUsage | None = None
instructions: str | None = None instructions: str | None = None
max_tool_calls: int | None = None
@json_schema_type @json_schema_type

View file

@ -74,7 +74,7 @@ class Benchmarks(Protocol):
""" """
... ...
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA) @webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
async def register_benchmark( async def register_benchmark(
self, self,
benchmark_id: str, benchmark_id: str,
@ -95,7 +95,7 @@ class Benchmarks(Protocol):
""" """
... ...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA) @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
async def unregister_benchmark(self, benchmark_id: str) -> None: async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark. """Unregister a benchmark.

View file

@ -146,7 +146,7 @@ class ListDatasetsResponse(BaseModel):
class Datasets(Protocol): class Datasets(Protocol):
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA) @webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA, deprecated=True)
async def register_dataset( async def register_dataset(
self, self,
purpose: DatasetPurpose, purpose: DatasetPurpose,
@ -235,7 +235,7 @@ class Datasets(Protocol):
""" """
... ...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA) @webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA, deprecated=True)
async def unregister_dataset( async def unregister_dataset(
self, self,
dataset_id: str, dataset_id: str,

View file

@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
)
class LogEvent:
def __init__(
self,
content: str = "",
end: str = "\n",
color="white",
):
self.content = content
self.color = color
self.end = "\n" if end is None else end
def print(self, flush=True):
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
class EventLogger:
async def log(self, event_generator):
async for chunk in event_generator:
if isinstance(chunk, ChatCompletionResponseStreamChunk):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
yield LogEvent("Assistant> ", color="cyan", end="")
elif event.event_type == ChatCompletionResponseEventType.progress:
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == ChatCompletionResponseEventType.complete:
yield LogEvent("")
else:
yield LogEvent("Assistant> ", color="cyan", end="")
yield LogEvent(chunk.completion_message.content, color="yellow")

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from enum import Enum from enum import Enum, StrEnum
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
@ -15,28 +15,18 @@ from typing import (
) )
from fastapi import Body from fastapi import Body
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field
from typing_extensions import TypedDict from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.responses import MetricResponseMixin, Order from llama_stack.apis.common.responses import (
Order,
)
from llama_stack.apis.common.tracing import telemetry_traceable from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
register_schema(ToolCall)
register_schema(ToolDefinition)
from enum import StrEnum
@json_schema_type @json_schema_type
class GreedySamplingStrategy(BaseModel): class GreedySamplingStrategy(BaseModel):
@ -201,58 +191,6 @@ class ToolResponseMessage(BaseModel):
content: InterleavedContent content: InterleavedContent
@json_schema_type
class CompletionMessage(BaseModel):
"""A message containing the model's (assistant) response in a chat conversation.
:param role: Must be "assistant" to identify this as the model's response
:param content: The content of the model's response
:param stop_reason: Reason why the model stopped generating. Options are:
- `StopReason.end_of_turn`: The model finished generating the entire response.
- `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response.
- `StopReason.out_of_tokens`: The model ran out of token budget.
:param tool_calls: List of tool calls. Each tool call is a ToolCall object.
"""
role: Literal["assistant"] = "assistant"
content: InterleavedContent
stop_reason: StopReason
tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
Message = Annotated[
UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
Field(discriminator="role"),
]
register_schema(Message, name="Message")
@json_schema_type
class ToolResponse(BaseModel):
"""Response from a tool invocation.
:param call_id: Unique identifier for the tool call this response is for
:param tool_name: Name of the tool that was invoked
:param content: The response content from the tool
:param metadata: (Optional) Additional metadata about the tool response
"""
call_id: str
tool_name: BuiltinTool | str
content: InterleavedContent
metadata: dict[str, Any] | None = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
class ToolChoice(Enum): class ToolChoice(Enum):
"""Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model. """Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
@ -289,22 +227,6 @@ class ChatCompletionResponseEventType(Enum):
progress = "progress" progress = "progress"
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""An event during chat completion generation.
:param event_type: Type of the event
:param delta: Content generated since last event. This can be one or more tokens, or a tool call.
:param logprobs: Optional log probabilities for generated tokens
:param stop_reason: Optional reason why generation stopped, if complete
"""
event_type: ChatCompletionResponseEventType
delta: ContentDelta
logprobs: list[TokenLogProbs] | None = None
stop_reason: StopReason | None = None
class ResponseFormatType(StrEnum): class ResponseFormatType(StrEnum):
"""Types of formats for structured (guided) decoding. """Types of formats for structured (guided) decoding.
@ -357,34 +279,6 @@ class CompletionRequest(BaseModel):
logprobs: LogProbConfig | None = None logprobs: LogProbConfig | None = None
@json_schema_type
class CompletionResponse(MetricResponseMixin):
"""Response from a completion request.
:param content: The generated completion text
:param stop_reason: Reason why generation stopped
:param logprobs: Optional log probabilities for generated tokens
"""
content: str
stop_reason: StopReason
logprobs: list[TokenLogProbs] | None = None
@json_schema_type
class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response.
:param delta: New content generated since last chunk. This can be one or more tokens.
:param stop_reason: Optional reason why generation stopped, if complete
:param logprobs: Optional log probabilities for generated tokens
"""
delta: str
stop_reason: StopReason | None = None
logprobs: list[TokenLogProbs] | None = None
class SystemMessageBehavior(Enum): class SystemMessageBehavior(Enum):
"""Config for how to override the default system prompt. """Config for how to override the default system prompt.
@ -398,70 +292,6 @@ class SystemMessageBehavior(Enum):
replace = "replace" replace = "replace"
@json_schema_type
class ToolConfig(BaseModel):
"""Configuration for tool use.
:param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
:param system_message_behavior: (Optional) Config for how to override the default system prompt.
- `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
- `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
'{{function_definitions}}' to indicate where the function definitions should be inserted.
"""
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
def model_post_init(self, __context: Any) -> None:
if isinstance(self.tool_choice, str):
try:
self.tool_choice = ToolChoice[self.tool_choice]
except KeyError:
pass
# This is an internally used class
@json_schema_type
class ChatCompletionRequest(BaseModel):
model: str
messages: list[Message]
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
response_format: ResponseFormat | None = None
stream: bool | None = False
logprobs: LogProbConfig | None = None
@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed chat completion response.
:param event: The event containing the new content
"""
event: ChatCompletionResponseEvent
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request.
:param completion_message: The complete response message
:param logprobs: Optional log probabilities for generated tokens
"""
completion_message: CompletionMessage
logprobs: list[TokenLogProbs] | None = None
@json_schema_type @json_schema_type
class EmbeddingsResponse(BaseModel): class EmbeddingsResponse(BaseModel):
"""Response containing generated embeddings. """Response containing generated embeddings.

View file

@ -76,7 +76,7 @@ class Inspect(Protocol):
List all available API routes with their methods and implementing providers. List all available API routes with their methods and implementing providers.
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes. :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns all non-deprecated routes.
:returns: Response containing information about all available routes. :returns: Response containing information about all available routes.
""" """
... ...

View file

@ -136,7 +136,7 @@ class Models(Protocol):
""" """
... ...
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1) @webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_model( async def register_model(
self, self,
model_id: str, model_id: str,
@ -158,7 +158,7 @@ class Models(Protocol):
""" """
... ...
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1) @webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
async def unregister_model( async def unregister_model(
self, self,
model_id: str, model_id: str,

View file

@ -178,7 +178,7 @@ class ScoringFunctions(Protocol):
""" """
... ...
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1) @webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_scoring_function( async def register_scoring_function(
self, self,
scoring_fn_id: str, scoring_fn_id: str,
@ -199,7 +199,9 @@ class ScoringFunctions(Protocol):
""" """
... ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1) @webmethod(
route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None: async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
"""Unregister a scoring function. """Unregister a scoring function.

View file

@ -67,7 +67,7 @@ class Shields(Protocol):
""" """
... ...
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1) @webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
@ -85,7 +85,7 @@ class Shields(Protocol):
""" """
... ...
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1) @webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
async def unregister_shield(self, identifier: str) -> None: async def unregister_shield(self, identifier: str) -> None:
"""Unregister a shield. """Unregister a shield.

View file

@ -109,7 +109,7 @@ class ListToolDefsResponse(BaseModel):
@runtime_checkable @runtime_checkable
@telemetry_traceable @telemetry_traceable
class ToolGroups(Protocol): class ToolGroups(Protocol):
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1) @webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
async def register_tool_group( async def register_tool_group(
self, self,
toolgroup_id: str, toolgroup_id: str,
@ -167,7 +167,7 @@ class ToolGroups(Protocol):
""" """
... ...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1) @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
async def unregister_toolgroup( async def unregister_toolgroup(
self, self,
toolgroup_id: str, toolgroup_id: str,

View file

@ -396,19 +396,19 @@ class VectorStoreListFilesResponse(BaseModel):
@json_schema_type @json_schema_type
class VectorStoreFileContentsResponse(BaseModel): class VectorStoreFileContentResponse(BaseModel):
"""Response from retrieving the contents of a vector store file. """Represents the parsed content of a vector store file.
:param file_id: Unique identifier for the file :param object: The object type, which is always `vector_store.file_content.page`
:param filename: Name of the file :param data: Parsed content of the file
:param attributes: Key-value attributes associated with the file :param has_more: Indicates if there are more content pages to fetch
:param content: List of content items from the file :param next_page: The token for the next page, if any
""" """
file_id: str object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page"
filename: str data: list[VectorStoreContent]
attributes: dict[str, Any] has_more: bool
content: list[VectorStoreContent] next_page: str | None = None
@json_schema_type @json_schema_type
@ -732,12 +732,12 @@ class VectorIO(Protocol):
self, self,
vector_store_id: str, vector_store_id: str,
file_id: str, file_id: str,
) -> VectorStoreFileContentsResponse: ) -> VectorStoreFileContentResponse:
"""Retrieves the contents of a vector store file. """Retrieves the contents of a vector store file.
:param vector_store_id: The ID of the vector store containing the file to retrieve. :param vector_store_id: The ID of the vector store containing the file to retrieve.
:param file_id: The ID of the file to retrieve. :param file_id: The ID of the file to retrieve.
:returns: A list of InterleavedContent representing the file contents. :returns: A VectorStoreFileContentResponse representing the file contents.
""" """
... ...

View file

@ -15,7 +15,6 @@ from llama_stack.apis.inspect import (
RouteInfo, RouteInfo,
VersionInfo, VersionInfo,
) )
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.server.routes import get_all_api_routes
@ -46,8 +45,8 @@ class DistributionInspectImpl(Inspect):
# Helper function to determine if a route should be included based on api_filter # Helper function to determine if a route should be included based on api_filter
def should_include_route(webmethod) -> bool: def should_include_route(webmethod) -> bool:
if api_filter is None: if api_filter is None:
# Default: only non-deprecated v1 APIs # Default: only non-deprecated APIs
return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1 return not webmethod.deprecated
elif api_filter == "deprecated": elif api_filter == "deprecated":
# Special filter: show deprecated routes regardless of their actual level # Special filter: show deprecated routes regardless of their actual level
return bool(webmethod.deprecated) return bool(webmethod.deprecated)

View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from llama_stack.apis.inference import Message from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
@ -52,7 +52,7 @@ class SafetyRouter(Safety):
async def run_shield( async def run_shield(
self, self,
shield_id: str, shield_id: str,
messages: list[Message], messages: list[OpenAIMessageParam],
params: dict[str, Any] = None, params: dict[str, Any] = None,
) -> RunShieldResponse: ) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}") logger.debug(f"SafetyRouter.run_shield: {shield_id}")

View file

@ -24,7 +24,7 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategyStaticConfig, VectorStoreChunkingStrategyStaticConfig,
VectorStoreDeleteResponse, VectorStoreDeleteResponse,
VectorStoreFileBatchObject, VectorStoreFileBatchObject,
VectorStoreFileContentsResponse, VectorStoreFileContentResponse,
VectorStoreFileDeleteResponse, VectorStoreFileDeleteResponse,
VectorStoreFileObject, VectorStoreFileObject,
VectorStoreFilesListInBatchResponse, VectorStoreFilesListInBatchResponse,
@ -338,7 +338,7 @@ class VectorIORouter(VectorIO):
self, self,
vector_store_id: str, vector_store_id: str,
file_id: str, file_id: str,
) -> VectorStoreFileContentsResponse: ) -> VectorStoreFileContentResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id) provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents( return await provider.openai_retrieve_vector_store_file_contents(

View file

@ -15,7 +15,7 @@ from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions, SearchRankingOptions,
VectorStoreChunkingStrategy, VectorStoreChunkingStrategy,
VectorStoreDeleteResponse, VectorStoreDeleteResponse,
VectorStoreFileContentsResponse, VectorStoreFileContentResponse,
VectorStoreFileDeleteResponse, VectorStoreFileDeleteResponse,
VectorStoreFileObject, VectorStoreFileObject,
VectorStoreFileStatus, VectorStoreFileStatus,
@ -195,7 +195,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
self, self,
vector_store_id: str, vector_store_id: str,
file_id: str, file_id: str,
) -> VectorStoreFileContentsResponse: ) -> VectorStoreFileContentResponse:
await self.assert_action_allowed("read", "vector_store", vector_store_id) await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents( return await provider.openai_retrieve_vector_store_file_contents(

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .oci import get_distribution_template # noqa: F401

View file

@ -0,0 +1,35 @@
version: 2
distribution_spec:
description: Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM
inference with scalable cloud services
providers:
inference:
- provider_type: remote::oci
vector_io:
- provider_type: inline::faiss
- provider_type: remote::chromadb
- provider_type: remote::pgvector
safety:
- provider_type: inline::llama-guard
agents:
- provider_type: inline::meta-reference
eval:
- provider_type: inline::meta-reference
datasetio:
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
files:
- provider_type: inline::localfs
image_type: venv
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]

View file

@ -0,0 +1,140 @@
---
orphan: true
---
# OCI Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}
## Prerequisites
### Oracle Cloud Infrastructure Setup
Before using the OCI Generative AI distribution, ensure you have:
1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/)
2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy
3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models
4. **Authentication**: Configure authentication using either:
- **Instance Principal** (recommended for cloud-hosted deployments)
- **API Key** (for on-premises or development environments)
### Authentication Methods
#### Instance Principal Authentication (Recommended)
Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments.
Requirements:
- Instance must be running in an Oracle Cloud Infrastructure compartment
- Instance must have appropriate IAM policies to access Generative AI services
#### API Key Authentication
For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file.
### Required IAM Policies
Ensure your OCI user or instance has the following policy statements:
```
Allow group <group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
Allow group <group_name> to manage generative-ai-inference-endpoints in compartment <compartment_name>
```
## Supported Services
### Inference: OCI Generative AI
Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports:
- **Chat Completions**: Conversational AI with context awareness
- **Text Generation**: Complete prompts and generate text content
#### Available Models
Common OCI Generative AI models include access to Meta, Cohere, OpenAI, Grok, and more models.
### Safety: Llama Guard
For content safety and moderation, this distribution uses Meta's LlamaGuard model through the OCI Generative AI service to provide:
- Content filtering and moderation
- Policy compliance checking
- Harmful content detection
### Vector Storage: Multiple Options
The distribution supports several vector storage providers:
- **FAISS**: Local in-memory vector search
- **ChromaDB**: Distributed vector database
- **PGVector**: PostgreSQL with vector extensions
### Additional Services
- **Dataset I/O**: Local filesystem and Hugging Face integration
- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities
- **Evaluation**: Meta reference evaluation framework
## Running Llama Stack with OCI
You can run the OCI distribution via Docker or local virtual environment.
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
```bash
OCI_AUTH=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci
```
### Configuration Examples
#### Using Instance Principal (Recommended for Production)
```bash
export OCI_AUTH_TYPE=instance_principal
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..<your-compartment-id>
```
#### Using API Key Authentication (Development)
```bash
export OCI_AUTH_TYPE=config_file
export OCI_CONFIG_FILE_PATH=~/.oci/config
export OCI_CLI_PROFILE=DEFAULT
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
```
## Regional Endpoints
OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit:
https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions
## Troubleshooting
### Common Issues
1. **Authentication Errors**: Verify your OCI credentials and IAM policies
2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region
3. **Permission Denied**: Check compartment permissions and Generative AI service access
4. **Region Unavailable**: Verify the specified region supports Generative AI services
### Getting Help
For additional support:
- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)
- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues)

View file

@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.oci.config import OCIConfig
def get_distribution_template(name: str = "oci") -> DistributionTemplate:
providers = {
"inference": [BuildProvider(provider_type="remote::oci")],
"vector_io": [
BuildProvider(provider_type="inline::faiss"),
BuildProvider(provider_type="remote::chromadb"),
BuildProvider(provider_type="remote::pgvector"),
],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [
BuildProvider(provider_type="remote::huggingface"),
BuildProvider(provider_type="inline::localfs"),
],
"scoring": [
BuildProvider(provider_type="inline::basic"),
BuildProvider(provider_type="inline::llm-as-judge"),
BuildProvider(provider_type="inline::braintrust"),
],
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
}
inference_provider = Provider(
provider_id="oci",
provider_type="remote::oci",
config=OCIConfig.sample_run_config(),
)
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
]
return DistributionTemplate(
name=name,
distro_type="remote_hosted",
description="Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM inference with scalable cloud services",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider],
"files": [files_provider],
},
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"OCI_AUTH_TYPE": (
"instance_principal",
"OCI authentication type (instance_principal or config_file)",
),
"OCI_REGION": (
"",
"OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1)",
),
"OCI_COMPARTMENT_OCID": (
"",
"OCI compartment ID for the Generative AI service",
),
"OCI_CONFIG_FILE_PATH": (
"~/.oci/config",
"OCI config file path (required if OCI_AUTH_TYPE is config_file)",
),
"OCI_CLI_PROFILE": (
"DEFAULT",
"OCI CLI profile name to use from config file",
),
},
)

View file

@ -0,0 +1,136 @@
version: 2
image_name: oci
apis:
- agents
- datasetio
- eval
- files
- inference
- safety
- scoring
- tool_runtime
- vector_io
providers:
inference:
- provider_id: oci
provider_type: remote::oci
config:
oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
oci_region: ${env.OCI_REGION:=us-ashburn-1}
oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
persistence:
namespace: vector_io::faiss
backend: kv_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/oci/files}
metadata_store:
table_name: files_metadata
backend: sql_default
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models: []
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
server:
port: 8321
telemetry:
enabled: true

View file

@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import (
) )
from termcolor import cprint from termcolor import cprint
from llama_stack.models.llama.datatypes import ToolPromptFormat
from ..checkpoint import maybe_reshard_state_dict from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
from .args import ModelArgs from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput from .chat_format import ChatFormat, LLMInput
from .model import Transformer from .model import Transformer

View file

@ -15,13 +15,10 @@ from pathlib import Path
from termcolor import colored from termcolor import colored
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat
from ..datatypes import ( from ..datatypes import (
BuiltinTool,
RawMessage, RawMessage,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
) )
from . import template_data from . import template_data
from .chat_format import ChatFormat from .chat_format import ChatFormat

View file

@ -15,7 +15,7 @@ import textwrap
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from llama_stack.apis.inference import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
ToolDefinition, ToolDefinition,
) )

View file

@ -8,8 +8,9 @@ import json
import re import re
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat from ..datatypes import RecursiveType
logger = get_logger(name=__name__, category="models::llama") logger = get_logger(name=__name__, category="models::llama")

View file

@ -13,7 +13,7 @@
import textwrap import textwrap
from llama_stack.apis.inference import ToolDefinition from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import ( from llama_stack.models.llama.llama3.prompt_templates.base import (
PromptTemplate, PromptTemplate,
PromptTemplateGeneratorBase, PromptTemplateGeneratorBase,

View file

@ -102,6 +102,7 @@ class MetaReferenceAgentsImpl(Agents):
include: list[str] | None = None, include: list[str] | None = None,
max_infer_iters: int | None = 10, max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None, guardrails: list[ResponseGuardrail] | None = None,
max_tool_calls: int | None = None,
) -> OpenAIResponseObject: ) -> OpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized" assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
result = await self.openai_responses_impl.create_openai_response( result = await self.openai_responses_impl.create_openai_response(
@ -119,6 +120,7 @@ class MetaReferenceAgentsImpl(Agents):
include, include,
max_infer_iters, max_infer_iters,
guardrails, guardrails,
max_tool_calls,
) )
return result # type: ignore[no-any-return] return result # type: ignore[no-any-return]

View file

@ -255,6 +255,7 @@ class OpenAIResponsesImpl:
include: list[str] | None = None, include: list[str] | None = None,
max_infer_iters: int | None = 10, max_infer_iters: int | None = 10,
guardrails: list[str | ResponseGuardrailSpec] | None = None, guardrails: list[str | ResponseGuardrailSpec] | None = None,
max_tool_calls: int | None = None,
): ):
stream = bool(stream) stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
@ -270,6 +271,9 @@ class OpenAIResponsesImpl:
if not conversation.startswith("conv_"): if not conversation.startswith("conv_"):
raise InvalidConversationIdError(conversation) raise InvalidConversationIdError(conversation)
if max_tool_calls is not None and max_tool_calls < 1:
raise ValueError(f"Invalid {max_tool_calls=}; should be >= 1")
stream_gen = self._create_streaming_response( stream_gen = self._create_streaming_response(
input=input, input=input,
conversation=conversation, conversation=conversation,
@ -282,6 +286,7 @@ class OpenAIResponsesImpl:
tools=tools, tools=tools,
max_infer_iters=max_infer_iters, max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids, guardrail_ids=guardrail_ids,
max_tool_calls=max_tool_calls,
) )
if stream: if stream:
@ -331,6 +336,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None, tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10, max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None, guardrail_ids: list[str] | None = None,
max_tool_calls: int | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]: ) -> AsyncIterator[OpenAIResponseObjectStream]:
# These should never be None when called from create_openai_response (which sets defaults) # These should never be None when called from create_openai_response (which sets defaults)
# but we assert here to help mypy understand the types # but we assert here to help mypy understand the types
@ -373,6 +379,7 @@ class OpenAIResponsesImpl:
safety_api=self.safety_api, safety_api=self.safety_api,
guardrail_ids=guardrail_ids, guardrail_ids=guardrail_ids,
instructions=instructions, instructions=instructions,
max_tool_calls=max_tool_calls,
) )
# Stream the response # Stream the response

View file

@ -115,6 +115,7 @@ class StreamingResponseOrchestrator:
safety_api, safety_api,
guardrail_ids: list[str] | None = None, guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None, prompt: OpenAIResponsePrompt | None = None,
max_tool_calls: int | None = None,
): ):
self.inference_api = inference_api self.inference_api = inference_api
self.ctx = ctx self.ctx = ctx
@ -126,6 +127,10 @@ class StreamingResponseOrchestrator:
self.safety_api = safety_api self.safety_api = safety_api
self.guardrail_ids = guardrail_ids or [] self.guardrail_ids = guardrail_ids or []
self.prompt = prompt self.prompt = prompt
# System message that is inserted into the model's context
self.instructions = instructions
# Max number of total calls to built-in tools that can be processed in a response
self.max_tool_calls = max_tool_calls
self.sequence_number = 0 self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing # Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ( self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@ -139,8 +144,8 @@ class StreamingResponseOrchestrator:
self.accumulated_usage: OpenAIResponseUsage | None = None self.accumulated_usage: OpenAIResponseUsage | None = None
# Track if we've sent a refusal response # Track if we've sent a refusal response
self.violation_detected = False self.violation_detected = False
# system message that is inserted into the model's context # Track total calls made to built-in tools
self.instructions = instructions self.accumulated_builtin_tool_calls = 0
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream: async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content.""" """Create a refusal response to replace streaming content."""
@ -186,6 +191,7 @@ class StreamingResponseOrchestrator:
usage=self.accumulated_usage, usage=self.accumulated_usage,
instructions=self.instructions, instructions=self.instructions,
prompt=self.prompt, prompt=self.prompt,
max_tool_calls=self.max_tool_calls,
) )
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]: async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@ -894,6 +900,11 @@ class StreamingResponseOrchestrator:
"""Coordinate execution of both function and non-function tool calls.""" """Coordinate execution of both function and non-function tool calls."""
# Execute non-function tool calls # Execute non-function tool calls
for tool_call in non_function_tool_calls: for tool_call in non_function_tool_calls:
# Check if total calls made to built-in and mcp tools exceed max_tool_calls
if self.max_tool_calls is not None and self.accumulated_builtin_tool_calls >= self.max_tool_calls:
logger.info(f"Ignoring built-in and mcp tool call since reached the limit of {self.max_tool_calls=}.")
break
# Find the item_id for this tool call # Find the item_id for this tool call
matching_item_id = None matching_item_id = None
for index, item_id in completion_result_data.tool_call_item_ids.items(): for index, item_id in completion_result_data.tool_call_item_ids.items():
@ -974,6 +985,9 @@ class StreamingResponseOrchestrator:
if tool_response_message: if tool_response_message:
next_turn_messages.append(tool_response_message) next_turn_messages.append(tool_response_message)
# Track number of calls made to built-in and mcp tools
self.accumulated_builtin_tool_calls += 1
# Execute function tool calls (client-side) # Execute function tool calls (client-side)
for tool_call in function_tool_calls: for tool_call in function_tool_calls:
# Find the item_id for this tool call from our tracking dictionary # Find the item_id for this tool call from our tracking dictionary

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import math import math
from collections.abc import Generator
from typing import Optional from typing import Optional
import torch import torch
@ -14,21 +13,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
GreedySamplingStrategy, GreedySamplingStrategy,
JsonSchemaResponseFormat, JsonSchemaResponseFormat,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIResponseFormatJSONSchema,
ResponseFormat, ResponseFormat,
ResponseFormatType,
SamplingParams, SamplingParams,
TopPSamplingStrategy, TopPSamplingStrategy,
) )
from llama_stack.models.llama.datatypes import QuantizationMode from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
from llama_stack.models.llama.llama3.generation import Llama3 from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model, ModelFamily from llama_stack.models.llama.sku_types import Model, ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
get_default_tool_prompt_format,
)
from .common import model_checkpoint_dir from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig from .config import MetaReferenceInferenceConfig
@ -106,14 +103,6 @@ def _infer_sampling_params(sampling_params: SamplingParams):
return temperature, top_p return temperature, top_p
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
tool_config = request.tool_config
if tool_config is not None and tool_config.tool_prompt_format is not None:
return tool_config.tool_prompt_format
else:
return get_default_tool_prompt_format(request.model)
class LlamaGenerator: class LlamaGenerator:
def __init__( def __init__(
self, self,
@ -157,55 +146,56 @@ class LlamaGenerator:
self.args = self.inner_generator.args self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter self.formatter = self.inner_generator.formatter
def completion(
self,
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
first_request.response_format,
),
)
def chat_completion( def chat_completion(
self, self,
request_batch: list[ChatCompletionRequestWithRawContent], request: OpenAIChatCompletionRequestWithExtraBody,
) -> Generator: raw_messages: list,
first_request = request_batch[0] ):
sampling_params = first_request.sampling_params or SamplingParams() """Generate chat completion using OpenAI request format.
Args:
request: OpenAI chat completion request
raw_messages: Pre-converted list of RawMessage objects
"""
# Determine tool prompt format
tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json
# Prepare sampling params
sampling_params = SamplingParams()
if request.temperature is not None or request.top_p is not None:
sampling_params.strategy = TopPSamplingStrategy(
temperature=request.temperature if request.temperature is not None else 1.0,
top_p=request.top_p if request.top_p is not None else 1.0,
)
if request.max_tokens:
sampling_params.max_tokens = request.max_tokens
max_gen_len = sampling_params.max_tokens max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1 max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params) temperature, top_p = _infer_sampling_params(sampling_params)
# Get logits processor for response format
logits_processor = None
if request.response_format:
if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
# Extract the actual schema from OpenAIJSONSchema TypedDict
schema_dict = request.response_format.json_schema.get("schema") or {}
json_schema_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema,
json_schema=schema_dict,
)
logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
# Generate
yield from self.inner_generator.generate( yield from self.inner_generator.generate(
llm_inputs=[ llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
for request in request_batch
],
max_gen_len=max_gen_len, max_gen_len=max_gen_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
logprobs=bool(first_request.logprobs), logprobs=False,
echo=False, echo=False,
logits_processor=get_logits_processor( logits_processor=logits_processor,
self.tokenizer,
self.args.vocab_size,
first_request.response_format,
),
) )

View file

@ -5,12 +5,19 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import time
import uuid
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
InferenceProvider, InferenceProvider,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody, OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionUsage,
OpenAIChoice,
OpenAICompletionRequestWithExtraBody, OpenAICompletionRequestWithExtraBody,
OpenAIUserMessageParam,
ToolChoice,
) )
from llama_stack.apis.inference.inference import ( from llama_stack.apis.inference.inference import (
OpenAIChatCompletion, OpenAIChatCompletion,
@ -19,12 +26,20 @@ from llama_stack.apis.inference.inference import (
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
JsonCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import ( from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
@ -44,6 +59,170 @@ log = get_logger(__name__, category="inference")
SEMAPHORE = asyncio.Semaphore(1) SEMAPHORE = asyncio.Semaphore(1)
def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
"""Convert OpenAI tool format to ToolDefinition format."""
# OpenAI tools have function.name and function.parameters
return ToolDefinition(
tool_name=tool.function.name,
description=tool.function.description or "",
parameters=tool.function.parameters or {},
)
def _get_tool_choice_prompt(tool_choice, tools) -> str:
"""Generate prompt text for tool_choice behavior."""
if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto":
return ""
elif tool_choice == ToolChoice.required or tool_choice == "required":
return "You MUST use one of the provided functions/tools to answer the user query."
elif tool_choice == ToolChoice.none or tool_choice == "none":
return ""
else:
# Specific tool specified
return f"You MUST use the tool `{tool_choice}` to answer the user query."
def _raw_content_as_str(content) -> str:
"""Convert RawContent to string for system messages."""
if isinstance(content, str):
return content
elif isinstance(content, RawTextItem):
return content.text
elif isinstance(content, list):
return "\n".join(_raw_content_as_str(c) for c in content)
else:
return "<media>"
def _augment_raw_messages_for_tools_llama_3_1(
raw_messages: list[RawMessage],
tools: list,
tool_choice,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions for Llama 3.1 style models."""
messages = raw_messages.copy()
existing_system_message = None
if messages and messages[0].role == "system":
existing_system_message = messages.pop(0)
sys_content = ""
# Add tool definitions first (if present)
if tools:
# Convert OpenAI tools to ToolDefinitions
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
# For OpenAI format, all tools are custom (have string names)
tool_gen = JsonCustomToolGenerator()
tool_template = tool_gen.gen(tool_definitions)
sys_content += tool_template.render()
sys_content += "\n"
# Add default system prompt
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content += default_template.render()
# Add existing system message if present
if existing_system_message:
sys_content += "\n" + _raw_content_as_str(existing_system_message.content)
# Add tool choice prompt if needed
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
sys_content += "\n" + tool_choice_prompt
# Create new system message
new_system_message = RawMessage(
role="system",
content=[RawTextItem(text=sys_content.strip())],
)
return [new_system_message] + messages
def _augment_raw_messages_for_tools_llama_4(
raw_messages: list[RawMessage],
tools: list,
tool_choice,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models."""
messages = raw_messages.copy()
existing_system_message = None
if messages and messages[0].role == "system":
existing_system_message = messages.pop(0)
sys_content = ""
# Add tool definitions if present
if tools:
# Convert OpenAI tools to ToolDefinitions
tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
# Use python_list format for Llama 4
tool_gen = PythonListCustomToolGeneratorLlama4()
system_prompt = None
if existing_system_message:
system_prompt = _raw_content_as_str(existing_system_message.content)
tool_template = tool_gen.gen(tool_definitions, system_prompt)
sys_content = tool_template.render()
elif existing_system_message:
# No tools, just use existing system message
sys_content = _raw_content_as_str(existing_system_message.content)
# Add tool choice prompt if needed
if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
sys_content += "\n" + tool_choice_prompt
if sys_content:
new_system_message = RawMessage(
role="system",
content=[RawTextItem(text=sys_content.strip())],
)
return [new_system_message] + messages
return messages
def augment_raw_messages_for_tools(
raw_messages: list[RawMessage],
params: OpenAIChatCompletionRequestWithExtraBody,
llama_model,
) -> list[RawMessage]:
"""Augment raw messages with tool definitions based on model family."""
if not params.tools:
return raw_messages
# Determine augmentation strategy based on model family
if llama_model.model_family == ModelFamily.llama3_1 or (
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
):
# Llama 3.1 and Llama 3.2 multimodal use JSON format
return _augment_raw_messages_for_tools_llama_3_1(
raw_messages,
params.tools,
params.tool_choice,
)
elif llama_model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# Llama 3.2/3.3/4 use python_list format
return _augment_raw_messages_for_tools_llama_4(
raw_messages,
params.tools,
params.tool_choice,
)
else:
# Default to Llama 3.1 style
return _augment_raw_messages_for_tools_llama_3_1(
raw_messages,
params.tools,
params.tool_choice,
)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator: def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model) return LlamaGenerator(config, model_id, llama_model)
@ -136,11 +315,14 @@ class MetaReferenceInferenceImpl(
self.llama_model = llama_model self.llama_model = llama_model
log.info("Warming up...") log.info("Warming up...")
await self.openai_chat_completion( await self.openai_chat_completion(
params=OpenAIChatCompletionRequestWithExtraBody(
model=model_id, model=model_id,
messages=[{"role": "user", "content": "Hi how are you?"}], messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
max_tokens=20, max_tokens=20,
) )
)
log.info("Warmed up!") log.info("Warmed up!")
def check_model(self, request) -> None: def check_model(self, request) -> None:
@ -155,4 +337,207 @@ class MetaReferenceInferenceImpl(
self, self,
params: OpenAIChatCompletionRequestWithExtraBody, params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider") self.check_model(params)
# Convert OpenAI messages to RawMessages
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_openai_message_to_raw_message,
decode_assistant_message,
)
raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
# Augment messages with tool definitions if tools are present
raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)
# Call generator's chat_completion method (works for both single-GPU and model-parallel)
if isinstance(self.generator, LlamaGenerator):
generator = self.generator.chat_completion(params, raw_messages)
else:
# Model parallel: submit task to process group
generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
# Check if streaming is requested
if params.stream:
return self._stream_chat_completion(generator, params)
# Non-streaming: collect all generated text
generated_text = ""
for result_batch in generator:
for result in result_batch:
if not result.ignore_token and result.source == "output":
generated_text += result.text
# Decode assistant message to extract tool calls and determine stop_reason
# Default to end_of_turn if generation completed normally
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
# Convert tool calls to OpenAI format
openai_tool_calls = None
if decoded_message.tool_calls:
from llama_stack.apis.inference import (
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
)
openai_tool_calls = [
OpenAIChatCompletionToolCall(
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
id=f"call_{uuid.uuid4().hex[:24]}",
type="function",
function=OpenAIChatCompletionToolCallFunction(
name=str(tc.tool_name),
arguments=tc.arguments,
),
)
for tc in decoded_message.tool_calls
]
# Determine finish_reason based on whether tool calls are present
finish_reason = "tool_calls" if openai_tool_calls else "stop"
# Extract content from decoded message
content = ""
if isinstance(decoded_message.content, str):
content = decoded_message.content
elif isinstance(decoded_message.content, list):
for item in decoded_message.content:
if isinstance(item, RawTextItem):
content += item.text
# Create OpenAI response
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
return OpenAIChatCompletion(
id=response_id,
object="chat.completion",
created=created,
model=params.model,
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant",
content=content,
tool_calls=openai_tool_calls,
),
finish_reason=finish_reason,
logprobs=None,
)
],
usage=OpenAIChatCompletionUsage(
prompt_tokens=0, # TODO: calculate properly
completion_tokens=0, # TODO: calculate properly
total_tokens=0, # TODO: calculate properly
),
)
async def _stream_chat_completion(
self,
generator,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""Stream chat completion chunks as they're generated."""
from llama_stack.apis.inference import (
OpenAIChatCompletionChunk,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoiceDelta,
OpenAIChunkChoice,
)
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
generated_text = ""
# Yield chunks as tokens are generated
for result_batch in generator:
for result in result_batch:
if result.ignore_token or result.source != "output":
continue
generated_text += result.text
# Yield delta chunk with the new text
chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(
role="assistant",
content=result.text,
),
finish_reason="",
logprobs=None,
)
],
)
yield chunk
# After generation completes, decode the full message to extract tool calls
decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
# If tool calls are present, yield a final chunk with tool_calls
if decoded_message.tool_calls:
openai_tool_calls = [
OpenAIChatCompletionToolCall(
# generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
id=f"call_{uuid.uuid4().hex[:24]}",
type="function",
function=OpenAIChatCompletionToolCallFunction(
name=str(tc.tool_name),
arguments=tc.arguments,
),
)
for tc in decoded_message.tool_calls
]
# Yield chunk with tool_calls
chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(
role="assistant",
tool_calls=openai_tool_calls,
),
finish_reason="",
logprobs=None,
)
],
)
yield chunk
finish_reason = "tool_calls"
else:
finish_reason = "stop"
# Yield final chunk with finish_reason
final_chunk = OpenAIChatCompletionChunk(
id=response_id,
object="chat.completion.chunk",
created=created,
model=params.model,
choices=[
OpenAIChunkChoice(
index=0,
delta=OpenAIChoiceDelta(),
finish_reason=finish_reason,
logprobs=None,
)
],
)
yield final_chunk

View file

@ -4,17 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Callable, Generator from collections.abc import Callable
from copy import deepcopy
from functools import partial from functools import partial
from typing import Any from typing import Any
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .parallel_utils import ModelParallelProcessGroup from .parallel_utils import ModelParallelProcessGroup
@ -23,12 +18,14 @@ class ModelRunner:
def __init__(self, llama): def __init__(self, llama):
self.llama = llama self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: Any): def __call__(self, task: Any):
if task[0] == "chat_completion": task_type = task[0]
return self.llama.chat_completion(task[1]) if task_type == "chat_completion":
# task[1] is [params, raw_messages]
params, raw_messages = task[1]
return self.llama.chat_completion(params, raw_messages)
else: else:
raise ValueError(f"Unexpected task type {task[0]}") raise ValueError(f"Unexpected task type {task_type}")
def init_model_cb( def init_model_cb(
@ -78,19 +75,3 @@ class LlamaModelParallelGenerator:
def __exit__(self, exc_type, exc_value, exc_traceback): def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop() self.group.stop()
def completion(
self,
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("completion", req_obj))
yield from gen
def chat_completion(
self,
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("chat_completion", req_obj))
yield from gen

View file

@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
log = get_logger(name=__name__, category="inference") log = get_logger(name=__name__, category="inference")
@ -69,10 +65,7 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel): class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: tuple[ task: tuple[str, list]
str,
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
]
class TaskResponse(BaseModel): class TaskResponse(BaseModel):
@ -328,10 +321,7 @@ class ModelParallelProcessGroup:
def run_inference( def run_inference(
self, self,
req: tuple[ req: tuple[str, list],
str,
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
],
) -> Generator: ) -> Generator:
assert not self.running, "inference already running" assert not self.running, "inference already running"

View file

@ -22,9 +22,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import ( from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
) )
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
)
from .config import SentenceTransformersInferenceConfig from .config import SentenceTransformersInferenceConfig
@ -32,7 +29,6 @@ log = get_logger(name=__name__, category="inference")
class SentenceTransformersInferenceImpl( class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
InferenceProvider, InferenceProvider,
ModelsProtocolPrivate, ModelsProtocolPrivate,

View file

@ -297,6 +297,20 @@ Available Models:
Azure OpenAI inference provider for accessing GPT models and other Azure services. Azure OpenAI inference provider for accessing GPT models and other Azure services.
Provider documentation Provider documentation
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
""",
),
RemoteProviderSpec(
api=Api.inference,
provider_type="remote::oci",
adapter_type="oci",
pip_packages=["oci"],
module="llama_stack.providers.remote.inference.oci",
config_class="llama_stack.providers.remote.inference.oci.config.OCIConfig",
provider_data_validator="llama_stack.providers.remote.inference.oci.config.OCIProviderDataValidator",
description="""
Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models.
Provider documentation
https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
""", """,
), ),
] ]

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import OCIConfig
async def get_adapter_impl(config: OCIConfig, _deps) -> InferenceProvider:
from .oci import OCIInferenceAdapter
adapter = OCIInferenceAdapter(config=config)
await adapter.initialize()
return adapter

View file

@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Generator, Mapping
from typing import Any, override
import httpx
import oci
import requests
from oci.config import DEFAULT_LOCATION, DEFAULT_PROFILE
OciAuthSigner = type[oci.signer.AbstractBaseSigner]
class HttpxOciAuth(httpx.Auth):
"""
Custom HTTPX authentication class that implements OCI request signing.
This class handles the authentication flow for HTTPX requests by signing them
using the OCI Signer, which adds the necessary authentication headers for
OCI API calls.
Attributes:
signer (oci.signer.Signer): The OCI signer instance used for request signing
"""
def __init__(self, signer: OciAuthSigner):
self.signer = signer
@override
def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
# Read the request content to handle streaming requests properly
try:
content = request.content
except httpx.RequestNotRead:
# For streaming requests, we need to read the content first
content = request.read()
req = requests.Request(
method=request.method,
url=str(request.url),
headers=dict(request.headers),
data=content,
)
prepared_request = req.prepare()
# Sign the request using the OCI Signer
self.signer.do_request_sign(prepared_request) # type: ignore
# Update the original HTTPX request with the signed headers
request.headers.update(prepared_request.headers)
yield request
class OciInstancePrincipalAuth(HttpxOciAuth):
def __init__(self, **kwargs: Mapping[str, Any]):
self.signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner(**kwargs)
class OciUserPrincipalAuth(HttpxOciAuth):
def __init__(self, config_file: str = DEFAULT_LOCATION, profile_name: str = DEFAULT_PROFILE):
config = oci.config.from_file(config_file, profile_name)
oci.config.validate_config(config) # type: ignore
key_content = ""
with open(config["key_file"]) as f:
key_content = f.read()
self.signer = oci.signer.Signer(
tenancy=config["tenancy"],
user=config["user"],
fingerprint=config["fingerprint"],
private_key_file_location=config.get("key_file"),
pass_phrase="none", # type: ignore
private_key_content=key_content,
)

View file

@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class OCIProviderDataValidator(BaseModel):
oci_auth_type: str = Field(
description="OCI authentication type (must be one of: instance_principal, config_file)",
)
oci_region: str = Field(
description="OCI region (e.g., us-ashburn-1)",
)
oci_compartment_id: str = Field(
description="OCI compartment ID for the Generative AI service",
)
oci_config_file_path: str | None = Field(
default="~/.oci/config",
description="OCI config file path (required if oci_auth_type is config_file)",
)
oci_config_profile: str | None = Field(
default="DEFAULT",
description="OCI config profile (required if oci_auth_type is config_file)",
)
@json_schema_type
class OCIConfig(RemoteInferenceProviderConfig):
oci_auth_type: str = Field(
description="OCI authentication type (must be one of: instance_principal, config_file)",
default_factory=lambda: os.getenv("OCI_AUTH_TYPE", "instance_principal"),
)
oci_region: str = Field(
default_factory=lambda: os.getenv("OCI_REGION", "us-ashburn-1"),
description="OCI region (e.g., us-ashburn-1)",
)
oci_compartment_id: str = Field(
default_factory=lambda: os.getenv("OCI_COMPARTMENT_OCID", ""),
description="OCI compartment ID for the Generative AI service",
)
oci_config_file_path: str = Field(
default_factory=lambda: os.getenv("OCI_CONFIG_FILE_PATH", "~/.oci/config"),
description="OCI config file path (required if oci_auth_type is config_file)",
)
oci_config_profile: str = Field(
default_factory=lambda: os.getenv("OCI_CLI_PROFILE", "DEFAULT"),
description="OCI config profile (required if oci_auth_type is config_file)",
)
@classmethod
def sample_run_config(
cls,
oci_auth_type: str = "${env.OCI_AUTH_TYPE:=instance_principal}",
oci_config_file_path: str = "${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}",
oci_config_profile: str = "${env.OCI_CLI_PROFILE:=DEFAULT}",
oci_region: str = "${env.OCI_REGION:=us-ashburn-1}",
oci_compartment_id: str = "${env.OCI_COMPARTMENT_OCID:=}",
**kwargs,
) -> dict[str, Any]:
return {
"oci_auth_type": oci_auth_type,
"oci_config_file_path": oci_config_file_path,
"oci_config_profile": oci_config_profile,
"oci_region": oci_region,
"oci_compartment_id": oci_compartment_id,
}

View file

@ -0,0 +1,140 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
from typing import Any
import httpx
import oci
from oci.generative_ai.generative_ai_client import GenerativeAiClient
from oci.generative_ai.models import ModelCollection
from openai._base_client import DefaultAsyncHttpxClient
from llama_stack.apis.inference.inference import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth
from llama_stack.providers.remote.inference.oci.config import OCIConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
logger = get_logger(name=__name__, category="inference::oci")
OCI_AUTH_TYPE_INSTANCE_PRINCIPAL = "instance_principal"
OCI_AUTH_TYPE_CONFIG_FILE = "config_file"
VALID_OCI_AUTH_TYPES = [OCI_AUTH_TYPE_INSTANCE_PRINCIPAL, OCI_AUTH_TYPE_CONFIG_FILE]
DEFAULT_OCI_REGION = "us-ashburn-1"
MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS", "CHAT"]
class OCIInferenceAdapter(OpenAIMixin):
config: OCIConfig
async def initialize(self) -> None:
"""Initialize and validate OCI configuration."""
if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
raise ValueError(
f"Invalid OCI authentication type: {self.config.oci_auth_type}."
f"Valid types are one of: {VALID_OCI_AUTH_TYPES}"
)
if not self.config.oci_compartment_id:
raise ValueError("OCI_COMPARTMENT_OCID is a required parameter. Either set in env variable or config.")
def get_base_url(self) -> str:
region = self.config.oci_region or DEFAULT_OCI_REGION
return f"https://inference.generativeai.{region}.oci.oraclecloud.com/20231130/actions/v1"
def get_api_key(self) -> str | None:
# OCI doesn't use API keys, it uses request signing
return "<NOTUSED>"
def get_extra_client_params(self) -> dict[str, Any]:
"""
Get extra parameters for the AsyncOpenAI client, including OCI-specific auth and headers.
"""
auth = self._get_auth()
compartment_id = self.config.oci_compartment_id or ""
return {
"http_client": DefaultAsyncHttpxClient(
auth=auth,
headers={
"CompartmentId": compartment_id,
},
),
}
def _get_oci_signer(self) -> oci.signer.AbstractBaseSigner | None:
if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
return oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
return None
def _get_oci_config(self) -> dict:
if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
config = {"region": self.config.oci_region}
elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
config = oci.config.from_file(self.config.oci_config_file_path, self.config.oci_config_profile)
if not config.get("region"):
raise ValueError(
"Region not specified in config. Please specify in config or with OCI_REGION env variable."
)
return config
def _get_auth(self) -> httpx.Auth:
if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
return OciInstancePrincipalAuth()
elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
return OciUserPrincipalAuth(
config_file=self.config.oci_config_file_path, profile_name=self.config.oci_config_profile
)
else:
raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")
async def list_provider_model_ids(self) -> Iterable[str]:
"""
List available models from OCI Generative AI service.
"""
oci_config = self._get_oci_config()
oci_signer = self._get_oci_signer()
compartment_id = self.config.oci_compartment_id or ""
if oci_signer is None:
client = GenerativeAiClient(config=oci_config)
else:
client = GenerativeAiClient(config=oci_config, signer=oci_signer)
models: ModelCollection = client.list_models(
compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE"
).data
seen_models = set()
model_ids = []
for model in models.items:
if model.time_deprecated or model.time_on_demand_retired:
continue
if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities:
continue
# Use display_name + model_type as the key to avoid conflicts
model_key = (model.display_name, ModelType.llm)
if model_key in seen_models:
continue
seen_models.add(model_key)
model_ids.append(model.display_name)
return model_ids
async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse:
# The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings.
raise NotImplementedError("OCI Provider does not (currently) support embeddings")

View file

@ -11,9 +11,7 @@ from collections.abc import AsyncIterator
import litellm import litellm
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest,
InferenceProvider, InferenceProvider,
JsonSchemaResponseFormat,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody, OpenAIChatCompletionRequestWithExtraBody,
@ -23,15 +21,11 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingsRequestWithExtraBody, OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage, OpenAIEmbeddingUsage,
ToolChoice,
) )
from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict_new,
convert_tooldef_to_openai_tool,
get_sampling_options,
prepare_openai_completion_params, prepare_openai_completion_params,
) )
@ -127,51 +121,6 @@ class LiteLLMOpenAIMixin(
return schema return schema
async def _get_params(self, request: ChatCompletionRequest) -> dict:
from typing import Any
input_dict: dict[str, Any] = {}
input_dict["messages"] = [
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
]
if fmt := request.response_format:
if not isinstance(fmt, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
# Convert to dict for manipulation
fmt_dict = dict(fmt.json_schema)
name = fmt_dict["title"]
del fmt_dict["title"]
fmt_dict["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt_dict = self._add_additional_properties_recursive(fmt_dict)
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt_dict,
"strict": self.json_schema_strict,
},
}
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config and (tool_choice := request.tool_config.tool_choice):
input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
return {
"model": request.model,
"api_key": self.get_api_key(),
"api_base": self.api_base,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
def get_api_key(self) -> str: def get_api_key(self) -> str:
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field key_field = self.provider_data_api_key_field

File diff suppressed because it is too large Load diff

View file

@ -21,19 +21,18 @@ from llama_stack.apis.common.content_types import (
TextContentItem, TextContentItem,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest, CompletionRequest,
Message, OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam, OpenAIChatCompletionContentPartTextParam,
OpenAIFile, OpenAIFile,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
ResponseFormat, ResponseFormat,
ResponseFormatType, ResponseFormatType,
SystemMessage,
SystemMessageBehavior,
ToolChoice, ToolChoice,
ToolDefinition,
UserMessage,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
@ -42,33 +41,19 @@ from llama_stack.models.llama.datatypes import (
RawMediaItem, RawMediaItem,
RawMessage, RawMessage,
RawTextItem, RawTextItem,
Role,
StopReason, StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
PythonListCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
log = get_logger(name=__name__, category="providers::utils") log = get_logger(name=__name__, category="providers::utils")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
messages: list[RawMessage]
class CompletionRequestWithRawContent(CompletionRequest): class CompletionRequestWithRawContent(CompletionRequest):
content: RawContent content: RawContent
@ -103,28 +88,6 @@ def interleaved_content_as_str(
return _process(content) return _process(content)
async def convert_request_to_raw(
request: ChatCompletionRequest | CompletionRequest,
) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
if isinstance(request, ChatCompletionRequest):
messages = []
for m in request.messages:
content = await interleaved_content_convert_to_raw(m.content)
d = m.model_dump()
d["content"] = content
messages.append(RawMessage(**d))
d = request.model_dump()
d["messages"] = messages
request = ChatCompletionRequestWithRawContent(**d)
else:
d = request.model_dump()
d["content"] = await interleaved_content_convert_to_raw(request.content)
request = CompletionRequestWithRawContent(**d)
return request
async def interleaved_content_convert_to_raw( async def interleaved_content_convert_to_raw(
content: InterleavedContent, content: InterleavedContent,
) -> RawContent: ) -> RawContent:
@ -171,6 +134,36 @@ async def interleaved_content_convert_to_raw(
return await _localize_single(content) return await _localize_single(content)
async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage:
"""Convert OpenAI message format to RawMessage format used by Llama formatters."""
if isinstance(message, OpenAIUserMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="user", content=content)
elif isinstance(message, OpenAISystemMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="system", content=content)
elif isinstance(message, OpenAIAssistantMessageParam):
content = await interleaved_content_convert_to_raw(message.content or "") # type: ignore[arg-type]
tool_calls = []
if message.tool_calls:
for tc in message.tool_calls:
if tc.function:
tool_calls.append(
ToolCall(
call_id=tc.id or "",
tool_name=tc.function.name or "",
arguments=tc.function.arguments or "{}",
)
)
return RawMessage(role="assistant", content=content, tool_calls=tool_calls)
elif isinstance(message, OpenAIToolMessageParam):
content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type]
return RawMessage(role="tool", content=content)
else:
# Handle OpenAIDeveloperMessageParam if needed
raise ValueError(f"Unsupported message type: {type(message)}")
def content_has_media(content: InterleavedContent): def content_has_media(content: InterleavedContent):
def _has_media_content(c): def _has_media_content(c):
return isinstance(c, ImageContentItem) return isinstance(c, ImageContentItem)
@ -181,17 +174,6 @@ def content_has_media(content: InterleavedContent):
return _has_media_content(content) return _has_media_content(content)
def messages_have_media(messages: list[Message]):
return any(content_has_media(m.content) for m in messages)
def request_has_media(request: ChatCompletionRequest | CompletionRequest):
if isinstance(request, ChatCompletionRequest):
return messages_have_media(request.messages)
else:
return content_has_media(request.content)
async def localize_image_content(uri: str) -> tuple[bytes, str] | None: async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
if uri.startswith("http"): if uri.startswith("http"):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@ -253,79 +235,6 @@ def augment_content_with_response_format_prompt(response_format, content):
return content return content
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
)
return formatter.tokenizer.decode(model_input.tokens)
async def chat_completion_request_to_model_input_info(
request: ChatCompletionRequest, llama_model: str
) -> tuple[str, int]:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
)
return (
formatter.tokenizer.decode(model_input.tokens),
len(model_input.tokens),
)
def chat_completion_request_to_messages(
request: ChatCompletionRequest,
llama_model: str,
) -> list[Message]:
"""Reads chat completion request and augments the messages to handle tools.
For eg. for llama_3_1, add system message with the appropriate tools or
add user messsage for custom tools, etc.
"""
assert llama_model is not None, "llama_model is required"
model = resolve_model(llama_model)
if model is None:
log.error(f"Could not resolve model {llama_model}")
return request.messages
allowed_models = supported_inference_models()
descriptors = [m.descriptor() for m in allowed_models]
if model.descriptor() not in descriptors:
log.error(f"Unsupported inference model? {model.descriptor()}")
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_1(request)
elif model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
):
# llama3.2, llama3.3 follow the same tool prompt format
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
elif model.model_family == ModelFamily.llama4:
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
else:
messages = request.messages
if fmt_prompt := response_format_prompt(request.response_format):
messages.append(UserMessage(content=fmt_prompt))
return messages
def response_format_prompt(fmt: ResponseFormat | None): def response_format_prompt(fmt: ResponseFormat | None):
if not fmt: if not fmt:
return None return None
@ -338,128 +247,6 @@ def response_format_prompt(fmt: ResponseFormat | None):
raise ValueError(f"Unknown response format {fmt.type}") raise ValueError(f"Unknown response format {fmt.type}")
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
messages.append(SystemMessage(content=sys_content))
has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif fmt == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages
def augment_messages_for_tools_llama(
request: ChatCompletionRequest,
custom_tool_prompt_generator,
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
sys_content = ""
custom_tools, builtin_tools = [], []
for t in request.tools:
if isinstance(t.tool_name, str):
custom_tools.append(t)
else:
builtin_tools.append(t)
if builtin_tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(builtin_tools)
sys_content += tool_template.render()
sys_content += "\n"
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
system_prompt = None
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
system_prompt = existing_system_message.content
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"
if existing_system_message and (
request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
):
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]
return messages
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str: def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
if tool_choice == ToolChoice.auto: if tool_choice == ToolChoice.auto:
return "" return ""

View file

@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
VectorStoreContent, VectorStoreContent,
VectorStoreDeleteResponse, VectorStoreDeleteResponse,
VectorStoreFileBatchObject, VectorStoreFileBatchObject,
VectorStoreFileContentsResponse, VectorStoreFileContentResponse,
VectorStoreFileCounts, VectorStoreFileCounts,
VectorStoreFileDeleteResponse, VectorStoreFileDeleteResponse,
VectorStoreFileLastError, VectorStoreFileLastError,
@ -921,22 +921,21 @@ class OpenAIVectorStoreMixin(ABC):
self, self,
vector_store_id: str, vector_store_id: str,
file_id: str, file_id: str,
) -> VectorStoreFileContentsResponse: ) -> VectorStoreFileContentResponse:
"""Retrieves the contents of a vector store file.""" """Retrieves the contents of a vector store file."""
if vector_store_id not in self.openai_vector_stores: if vector_store_id not in self.openai_vector_stores:
raise VectorStoreNotFoundError(vector_store_id) raise VectorStoreNotFoundError(vector_store_id)
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
chunks = [Chunk.model_validate(c) for c in dict_chunks] chunks = [Chunk.model_validate(c) for c in dict_chunks]
content = [] content = []
for chunk in chunks: for chunk in chunks:
content.extend(self._chunk_to_vector_store_content(chunk)) content.extend(self._chunk_to_vector_store_content(chunk))
return VectorStoreFileContentsResponse( return VectorStoreFileContentResponse(
file_id=file_id, object="vector_store.file_content.page",
filename=file_info.get("filename", ""), data=content,
attributes=file_info.get("attributes", {}), has_more=False,
content=content, next_page=None,
) )
async def openai_update_vector_store_file( async def openai_update_vector_store_file(

View file

@ -516,3 +516,169 @@ def test_response_with_instructions(openai_client, client_with_models, text_mode
# Verify instructions from previous response was not carried over to the next response # Verify instructions from previous response was not carried over to the next response
assert response_with_instructions2.instructions == instructions2 assert response_with_instructions2.instructions == instructions2
@pytest.mark.skip(reason="Tool calling is not reliable.")
def test_max_tool_calls_with_function_tools(openai_client, client_with_models, text_model_id):
"""Test handling of max_tool_calls with function tools in responses."""
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
client = openai_client
max_tool_calls = 1
tools = [
{
"type": "function",
"name": "get_weather",
"description": "Get weather information for a specified location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name (e.g., 'New York', 'London')",
},
},
},
},
{
"type": "function",
"name": "get_time",
"description": "Get current time for a specified location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name (e.g., 'New York', 'London')",
},
},
},
},
]
# First create a response that triggers function tools
response = client.responses.create(
model=text_model_id,
input="Can you tell me the weather in Paris and the current time?",
tools=tools,
stream=False,
max_tool_calls=max_tool_calls,
)
# Verify we got two function calls and that the max_tool_calls do not affect function tools
assert len(response.output) == 2
assert response.output[0].type == "function_call"
assert response.output[0].name == "get_weather"
assert response.output[0].status == "completed"
assert response.output[1].type == "function_call"
assert response.output[1].name == "get_time"
assert response.output[0].status == "completed"
# Verify we have a valid max_tool_calls field
assert response.max_tool_calls == max_tool_calls
def test_max_tool_calls_invalid(openai_client, client_with_models, text_model_id):
"""Test handling of invalid max_tool_calls in responses."""
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
client = openai_client
input = "Search for today's top technology news."
invalid_max_tool_calls = 0
tools = [
{"type": "web_search"},
]
# Create a response with an invalid max_tool_calls value i.e. 0
# Handle ValueError from LLS and BadRequestError from OpenAI client
with pytest.raises((ValueError, BadRequestError)) as excinfo:
client.responses.create(
model=text_model_id,
input=input,
tools=tools,
stream=False,
max_tool_calls=invalid_max_tool_calls,
)
error_message = str(excinfo.value)
assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, (
f"Expected error message about invalid max_tool_calls, got: {error_message}"
)
def test_max_tool_calls_with_builtin_tools(openai_client, client_with_models, text_model_id):
"""Test handling of max_tool_calls with built-in tools in responses."""
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
client = openai_client
input = "Search for today's top technology and a positive news story. You MUST make exactly two separate web search calls."
max_tool_calls = [1, 5]
tools = [
{"type": "web_search"},
]
# First create a response that triggers web_search tools without max_tool_calls
response = client.responses.create(
model=text_model_id,
input=input,
tools=tools,
stream=False,
)
# Verify we got two web search calls followed by a message
assert len(response.output) == 3
assert response.output[0].type == "web_search_call"
assert response.output[0].status == "completed"
assert response.output[1].type == "web_search_call"
assert response.output[1].status == "completed"
assert response.output[2].type == "message"
assert response.output[2].status == "completed"
assert response.output[2].role == "assistant"
# Next create a response that triggers web_search tools with max_tool_calls set to 1
response_2 = client.responses.create(
model=text_model_id,
input=input,
tools=tools,
stream=False,
max_tool_calls=max_tool_calls[0],
)
# Verify we got one web search tool call followed by a message
assert len(response_2.output) == 2
assert response_2.output[0].type == "web_search_call"
assert response_2.output[0].status == "completed"
assert response_2.output[1].type == "message"
assert response_2.output[1].status == "completed"
assert response_2.output[1].role == "assistant"
# Verify we have a valid max_tool_calls field
assert response_2.max_tool_calls == max_tool_calls[0]
# Finally create a response that triggers web_search tools with max_tool_calls set to 5
response_3 = client.responses.create(
model=text_model_id,
input=input,
tools=tools,
stream=False,
max_tool_calls=max_tool_calls[1],
)
# Verify we got two web search calls followed by a message
assert len(response_3.output) == 3
assert response_3.output[0].type == "web_search_call"
assert response_3.output[0].status == "completed"
assert response_3.output[1].type == "web_search_call"
assert response_3.output[1].status == "completed"
assert response_3.output[2].type == "message"
assert response_3.output[2].status == "completed"
assert response_3.output[2].role == "assistant"
# Verify we have a valid max_tool_calls field
assert response_3.max_tool_calls == max_tool_calls[1]

View file

@ -54,6 +54,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
# {"error":{"message":"Unknown request URL: GET /openai/v1/completions. Please check the URL for typos, # {"error":{"message":"Unknown request URL: GET /openai/v1/completions. Please check the URL for typos,
# or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}} # or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}}
"remote::groq", "remote::groq",
"remote::oci",
"remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404 "remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404
"remote::anthropic", # at least claude-3-{5,7}-{haiku,sonnet}-* / claude-{sonnet,opus}-4-* are not supported "remote::anthropic", # at least claude-3-{5,7}-{haiku,sonnet}-* / claude-{sonnet,opus}-4-* are not supported
"remote::azure", # {'error': {'code': 'OperationNotSupported', 'message': 'The completion operation "remote::azure", # {'error': {'code': 'OperationNotSupported', 'message': 'The completion operation

View file

@ -138,6 +138,7 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
"remote::runpod", "remote::runpod",
"remote::sambanova", "remote::sambanova",
"remote::tgi", "remote::tgi",
"remote::oci",
): ):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.") pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")

View file

@ -907,16 +907,16 @@ def test_openai_vector_store_retrieve_file_contents(
) )
assert file_contents is not None assert file_contents is not None
assert len(file_contents.content) == 1 assert file_contents.object == "vector_store.file_content.page"
content = file_contents.content[0] assert len(file_contents.data) == 1
content = file_contents.data[0]
# llama-stack-client returns a model, openai-python is a badboy and returns a dict # llama-stack-client returns a model, openai-python is a badboy and returns a dict
if not isinstance(content, dict): if not isinstance(content, dict):
content = content.model_dump() content = content.model_dump()
assert content["type"] == "text" assert content["type"] == "text"
assert content["text"] == test_content.decode("utf-8") assert content["text"] == test_content.decode("utf-8")
assert file_contents.filename == file_name assert file_contents.has_more is False
assert file_contents.attributes == attributes
@vector_provider_wrapper @vector_provider_wrapper
@ -1483,14 +1483,12 @@ def test_openai_vector_store_file_batch_retrieve_contents(
) )
assert file_contents is not None assert file_contents is not None
assert file_contents.filename == file_data[i][0] assert file_contents.object == "vector_store.file_content.page"
assert len(file_contents.content) > 0 assert len(file_contents.data) > 0
# Verify the content matches what we uploaded # Verify the content matches what we uploaded
content_text = ( content_text = (
file_contents.content[0].text file_contents.data[0].text if hasattr(file_contents.data[0], "text") else file_contents.data[0]["text"]
if hasattr(file_contents.content[0], "text")
else file_contents.content[0]["text"]
) )
assert file_data[i][1].decode("utf-8") in content_text assert file_data[i][1].decode("utf-8") in content_text

View file

@ -1,303 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionMessage,
StopReason,
SystemMessage,
SystemMessageBehavior,
ToolCall,
ToolConfig,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
chat_completion_request_to_prompt,
interleaved_content_as_str,
)
MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"
async def test_system_default():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
async def test_system_builtin_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
async def test_system_custom_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
)
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 3
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
async def test_system_custom_and_builtin():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 3
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
async def test_completion_message_encoding():
request = ChatCompletionRequest(
model=MODEL3_2,
messages=[
UserMessage(content="hello"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
tool_name="custom1",
arguments='{"param1": "value1"}', # arguments must be a JSON string
call_id="123",
)
],
),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
)
prompt = await chat_completion_request_to_prompt(request, request.model)
assert '[custom1(param1="value1")]' in prompt
request.model = MODEL
request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
prompt = await chat_completion_request_to_prompt(request, request.model)
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
async def test_user_provided_system_message():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert messages[-1].content == content
async def test_replace_system_message_behavior_builtin_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools_with_template():
content = "Hello !"
system_prompt = "You are a pirate {{ function_description }}"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
# function description is present in the system prompt
assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import Mock
import pytest
from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
ModelRunner,
)
class TestModelRunner:
"""Test ModelRunner task dispatching for model-parallel inference."""
def test_chat_completion_task_dispatch(self):
"""Verify ModelRunner correctly dispatches chat_completion tasks."""
# Create a mock generator
mock_generator = Mock()
mock_generator.chat_completion = Mock(return_value=iter([]))
runner = ModelRunner(mock_generator)
# Create a chat_completion task
fake_params = {"model": "test"}
fake_messages = [{"role": "user", "content": "test"}]
task = ("chat_completion", [fake_params, fake_messages])
# Execute task
runner(task)
# Verify chat_completion was called with correct arguments
mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)
def test_invalid_task_type_raises_error(self):
"""Verify ModelRunner rejects invalid task types."""
mock_generator = Mock()
runner = ModelRunner(mock_generator)
with pytest.raises(ValueError, match="Unexpected task type"):
runner(("invalid_task", []))

View file

@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from llama_stack.apis.inference import CompletionMessage, UserMessage from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
# Run the shield # Run the shield
messages = [ messages = [
UserMessage(role="user", content="Hello, how are you?"), OpenAIUserMessageParam(content="Hello, how are you?"),
CompletionMessage( OpenAIAssistantMessageParam(
role="assistant",
content="I'm doing well, thank you for asking!", content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[], tool_calls=[],
), ),
] ]
@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
# Mock Guardrails API response # Mock Guardrails API response
mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}} mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
# Run the shield
messages = [ messages = [
UserMessage(role="user", content="Hello, how are you?"), OpenAIUserMessageParam(content="Hello, how are you?"),
CompletionMessage( OpenAIAssistantMessageParam(
role="assistant",
content="I'm doing well, thank you for asking!", content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[], tool_calls=[],
), ),
] ]
@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
adapter.shield_store.get_shield.return_value = None adapter.shield_store.get_shield.return_value = None
messages = [ messages = [
UserMessage(role="user", content="Hello, how are you?"), OpenAIUserMessageParam(content="Hello, how are you?"),
] ]
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
# Running the shield should raise an exception # Running the shield should raise an exception
messages = [ messages = [
UserMessage(role="user", content="Hello, how are you?"), OpenAIUserMessageParam(content="Hello, how are you?"),
CompletionMessage( OpenAIAssistantMessageParam(
role="assistant",
content="I'm doing well, thank you for asking!", content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[], tool_calls=[],
), ),
] ]

View file

@ -1,220 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from pydantic import ValidationError
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
CompletionMessage,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
SystemMessage,
UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
convert_message_to_openai_dict_new,
openai_messages_to_messages,
)
async def test_convert_message_to_openai_dict():
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
assert await convert_message_to_openai_dict(message) == {
"role": "user",
"content": [{"type": "text", "text": "Hello, world!"}],
}
# Test convert_message_to_openai_dict with a tool call
async def test_convert_message_to_openai_dict_with_tool_call():
message = CompletionMessage(
content="",
tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
stop_reason=StopReason.end_of_turn,
)
openai_dict = await convert_message_to_openai_dict(message)
assert openai_dict == {
"role": "assistant",
"content": [{"type": "text", "text": ""}],
"tool_calls": [
{"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
],
}
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
message = CompletionMessage(
content="",
tool_calls=[
ToolCall(
call_id="123",
tool_name=BuiltinTool.brave_search,
arguments='{"foo": "bar"}',
)
],
stop_reason=StopReason.end_of_turn,
)
openai_dict = await convert_message_to_openai_dict(message)
assert openai_dict == {
"role": "assistant",
"content": [{"type": "text", "text": ""}],
"tool_calls": [
{"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
],
}
async def test_openai_messages_to_messages_with_content_str():
openai_messages = [
OpenAISystemMessageParam(content="system message"),
OpenAIUserMessageParam(content="user message"),
OpenAIAssistantMessageParam(content="assistant message"),
]
llama_messages = openai_messages_to_messages(openai_messages)
assert len(llama_messages) == 3
assert isinstance(llama_messages[0], SystemMessage)
assert isinstance(llama_messages[1], UserMessage)
assert isinstance(llama_messages[2], CompletionMessage)
assert llama_messages[0].content == "system message"
assert llama_messages[1].content == "user message"
assert llama_messages[2].content == "assistant message"
async def test_openai_messages_to_messages_with_content_list():
openai_messages = [
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
]
llama_messages = openai_messages_to_messages(openai_messages)
assert len(llama_messages) == 3
assert isinstance(llama_messages[0], SystemMessage)
assert isinstance(llama_messages[1], UserMessage)
assert isinstance(llama_messages[2], CompletionMessage)
assert llama_messages[0].content[0].text == "system message"
assert llama_messages[1].content[0].text == "user message"
assert llama_messages[2].content[0].text == "assistant message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_string(message_class, kwargs):
"""Test that messages accept string text content."""
msg = message_class(content="Test message", **kwargs)
assert msg.content == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_list(message_class, kwargs):
"""Test that messages accept list of text content parts."""
content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
msg = message_class(content=content_list, **kwargs)
assert len(msg.content) == 1
assert msg.content[0].text == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_rejects_images(message_class, kwargs):
"""Test that system, assistant, developer, and tool messages reject image content."""
with pytest.raises(ValidationError):
message_class(
content=[
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
],
**kwargs,
)
def test_user_message_accepts_images():
"""Test that user messages accept image content (unlike other message types)."""
# List with images should work
msg = OpenAIUserMessageParam(
content=[
OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
]
)
assert len(msg.content) == 2
assert msg.content[0].text == "Describe this image:"
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
async def test_convert_message_to_openai_dict_new_user_message():
"""Test convert_message_to_openai_dict_new with UserMessage."""
message = UserMessage(content="Hello, world!", role="user")
result = await convert_message_to_openai_dict_new(message)
assert result["role"] == "user"
assert result["content"] == "Hello, world!"
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
message = CompletionMessage(
content="I'll help you find the weather.",
tool_calls=[
ToolCall(
call_id="call_123",
tool_name="get_weather",
arguments='{"city": "Sligo"}',
)
],
stop_reason=StopReason.end_of_turn,
)
result = await convert_message_to_openai_dict_new(message)
# This would have failed with "Cannot instantiate typing.Union" before the fix
assert result["role"] == "assistant"
assert result["content"] == "I'll help you find the weather."
assert "tool_calls" in result
assert result["tool_calls"] is not None
assert len(result["tool_calls"]) == 1
tool_call = result["tool_calls"][0]
assert tool_call.id == "call_123"
assert tool_call.type == "function"
assert tool_call.function.name == "get_weather"
assert tool_call.function.arguments == '{"city": "Sligo"}'

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.models.llama.datatypes import RawTextItem
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_openai_message_to_raw_message,
)
class TestConvertOpenAIMessageToRawMessage:
"""Test conversion of OpenAI message types to RawMessage format."""
async def test_user_message_conversion(self):
msg = OpenAIUserMessageParam(role="user", content="Hello world")
raw_msg = await convert_openai_message_to_raw_message(msg)
assert raw_msg.role == "user"
assert isinstance(raw_msg.content, RawTextItem)
assert raw_msg.content.text == "Hello world"
async def test_assistant_message_conversion(self):
msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
raw_msg = await convert_openai_message_to_raw_message(msg)
assert raw_msg.role == "assistant"
assert isinstance(raw_msg.content, RawTextItem)
assert raw_msg.content.text == "Hi there!"
assert raw_msg.tool_calls == []