Merge branch 'main' into chore_build

Wen Zhou 2025-06-26 13:30:43 +02:00 committed by GitHub
commit 604e42c56d
196 changed files with 2332 additions and 1515 deletions

View file

@ -11,6 +11,8 @@ on:
- 'llama_stack/distribution/*.sh'
- '.github/workflows/providers-build.yml'
- 'llama_stack/templates/**'
- 'pyproject.toml'
pull_request:
paths:
- 'llama_stack/cli/stack/build.py'
@ -19,6 +21,7 @@ on:
- 'llama_stack/distribution/*.sh'
- '.github/workflows/providers-build.yml'
- 'llama_stack/templates/**'
- 'pyproject.toml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}

View file

@ -7390,6 +7390,147 @@
],
"title": "AgentTurnResponseTurnStartPayload"
},
"OpenAIResponseAnnotationCitation": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "url_citation",
"default": "url_citation"
},
"end_index": {
"type": "integer"
},
"start_index": {
"type": "integer"
},
"title": {
"type": "string"
},
"url": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"end_index",
"start_index",
"title",
"url"
],
"title": "OpenAIResponseAnnotationCitation"
},
"OpenAIResponseAnnotationContainerFileCitation": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "container_file_citation",
"default": "container_file_citation"
},
"container_id": {
"type": "string"
},
"end_index": {
"type": "integer"
},
"file_id": {
"type": "string"
},
"filename": {
"type": "string"
},
"start_index": {
"type": "integer"
}
},
"additionalProperties": false,
"required": [
"type",
"container_id",
"end_index",
"file_id",
"filename",
"start_index"
],
"title": "OpenAIResponseAnnotationContainerFileCitation"
},
"OpenAIResponseAnnotationFileCitation": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "file_citation",
"default": "file_citation"
},
"file_id": {
"type": "string"
},
"filename": {
"type": "string"
},
"index": {
"type": "integer"
}
},
"additionalProperties": false,
"required": [
"type",
"file_id",
"filename",
"index"
],
"title": "OpenAIResponseAnnotationFileCitation"
},
"OpenAIResponseAnnotationFilePath": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "file_path",
"default": "file_path"
},
"file_id": {
"type": "string"
},
"index": {
"type": "integer"
}
},
"additionalProperties": false,
"required": [
"type",
"file_id",
"index"
],
"title": "OpenAIResponseAnnotationFilePath"
},
"OpenAIResponseAnnotations": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIResponseAnnotationFileCitation"
},
{
"$ref": "#/components/schemas/OpenAIResponseAnnotationCitation"
},
{
"$ref": "#/components/schemas/OpenAIResponseAnnotationContainerFileCitation"
},
{
"$ref": "#/components/schemas/OpenAIResponseAnnotationFilePath"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"file_citation": "#/components/schemas/OpenAIResponseAnnotationFileCitation",
"url_citation": "#/components/schemas/OpenAIResponseAnnotationCitation",
"container_file_citation": "#/components/schemas/OpenAIResponseAnnotationContainerFileCitation",
"file_path": "#/components/schemas/OpenAIResponseAnnotationFilePath"
}
}
},
"OpenAIResponseInput": {
"oneOf": [
{
@ -7764,6 +7905,10 @@
"type": "string",
"const": "web_search"
},
{
"type": "string",
"const": "web_search_preview"
},
{
"type": "string",
"const": "web_search_preview_2025_03_11"
@ -7855,12 +8000,19 @@
"type": "string",
"const": "output_text",
"default": "output_text"
},
"annotations": {
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIResponseAnnotations"
}
}
},
"additionalProperties": false,
"required": [
"text",
"type"
"type",
"annotations"
],
"title": "OpenAIResponseOutputMessageContentOutputText"
},
@ -11190,6 +11342,115 @@
],
"title": "InsertRequest"
},
"Chunk": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"description": "The content of the chunk, which can be interleaved text, images, or other types."
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Metadata associated with the chunk that will be used in the model context during inference."
},
"embedding": {
"type": "array",
"items": {
"type": "number"
},
"description": "Optional embedding for the chunk. If not provided, it will be computed later."
},
"stored_chunk_id": {
"type": "string",
"description": "The chunk ID that is stored in the vector database. Used for backend functionality."
},
"chunk_metadata": {
"$ref": "#/components/schemas/ChunkMetadata",
"description": "Metadata for the chunk that will NOT be used in the context during inference. The `chunk_metadata` is required backend functionality."
}
},
"additionalProperties": false,
"required": [
"content",
"metadata"
],
"title": "Chunk",
"description": "A chunk of content that can be inserted into a vector database."
},
"ChunkMetadata": {
"type": "object",
"properties": {
"chunk_id": {
"type": "string",
"description": "The ID of the chunk. If not set, it will be generated based on the document ID and content."
},
"document_id": {
"type": "string",
"description": "The ID of the document this chunk belongs to."
},
"source": {
"type": "string",
"description": "The source of the content, such as a URL, file path, or other identifier."
},
"created_timestamp": {
"type": "integer",
"description": "An optional timestamp indicating when the chunk was created."
},
"updated_timestamp": {
"type": "integer",
"description": "An optional timestamp indicating when the chunk was last updated."
},
"chunk_window": {
"type": "string",
"description": "The window of the chunk, which can be used to group related chunks together."
},
"chunk_tokenizer": {
"type": "string",
"description": "The tokenizer used to create the chunk. Default is Tiktoken."
},
"chunk_embedding_model": {
"type": "string",
"description": "The embedding model used to create the chunk's embedding."
},
"chunk_embedding_dimension": {
"type": "integer",
"description": "The dimension of the embedding vector for the chunk."
},
"content_token_count": {
"type": "integer",
"description": "The number of tokens in the content of the chunk."
},
"metadata_token_count": {
"type": "integer",
"description": "The number of tokens in the metadata of the chunk."
}
},
"additionalProperties": false,
"title": "ChunkMetadata",
"description": "`ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk that will not be used in the context during inference, but is required for backend functionality. The `ChunkMetadata` is set during chunk creation in `MemoryToolRuntimeImpl().insert()`and is not expected to change after. Use `Chunk.metadata` for metadata that will be used in the context during inference."
},
"InsertChunksRequest": {
"type": "object",
"properties": {
@ -11200,53 +11461,7 @@
"chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"description": "The content of the chunk, which can be interleaved text, images, or other types."
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
},
"embedding": {
"type": "array",
"items": {
"type": "number"
},
"description": "Optional embedding for the chunk. If not provided, it will be computed later."
}
},
"additionalProperties": false,
"required": [
"content",
"metadata"
],
"title": "Chunk",
"description": "A chunk of content that can be inserted into a vector database."
"$ref": "#/components/schemas/Chunk"
},
"description": "The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types. `metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional. If `metadata` is provided, you configure how Llama Stack formats the chunk during generation. If `embedding` is not provided, it will be computed later."
},
@ -14671,53 +14886,7 @@
"chunks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"description": "The content of the chunk, which can be interleaved text, images, or other types."
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Metadata associated with the chunk, such as document ID, source, or other relevant information."
},
"embedding": {
"type": "array",
"items": {
"type": "number"
},
"description": "Optional embedding for the chunk. If not provided, it will be computed later."
}
},
"additionalProperties": false,
"required": [
"content",
"metadata"
],
"title": "Chunk",
"description": "A chunk of content that can be inserted into a vector database."
"$ref": "#/components/schemas/Chunk"
}
},
"scores": {

View file

@ -5263,6 +5263,106 @@ components:
- event_type
- turn_id
title: AgentTurnResponseTurnStartPayload
OpenAIResponseAnnotationCitation:
type: object
properties:
type:
type: string
const: url_citation
default: url_citation
end_index:
type: integer
start_index:
type: integer
title:
type: string
url:
type: string
additionalProperties: false
required:
- type
- end_index
- start_index
- title
- url
title: OpenAIResponseAnnotationCitation
"OpenAIResponseAnnotationContainerFileCitation":
type: object
properties:
type:
type: string
const: container_file_citation
default: container_file_citation
container_id:
type: string
end_index:
type: integer
file_id:
type: string
filename:
type: string
start_index:
type: integer
additionalProperties: false
required:
- type
- container_id
- end_index
- file_id
- filename
- start_index
title: >-
OpenAIResponseAnnotationContainerFileCitation
OpenAIResponseAnnotationFileCitation:
type: object
properties:
type:
type: string
const: file_citation
default: file_citation
file_id:
type: string
filename:
type: string
index:
type: integer
additionalProperties: false
required:
- type
- file_id
- filename
- index
title: OpenAIResponseAnnotationFileCitation
OpenAIResponseAnnotationFilePath:
type: object
properties:
type:
type: string
const: file_path
default: file_path
file_id:
type: string
index:
type: integer
additionalProperties: false
required:
- type
- file_id
- index
title: OpenAIResponseAnnotationFilePath
OpenAIResponseAnnotations:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseAnnotationFileCitation'
- $ref: '#/components/schemas/OpenAIResponseAnnotationCitation'
- $ref: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation'
- $ref: '#/components/schemas/OpenAIResponseAnnotationFilePath'
discriminator:
propertyName: type
mapping:
file_citation: '#/components/schemas/OpenAIResponseAnnotationFileCitation'
url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation'
container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation'
file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath'
OpenAIResponseInput:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
@ -5488,6 +5588,8 @@ components:
oneOf:
- type: string
const: web_search
- type: string
const: web_search_preview
- type: string
const: web_search_preview_2025_03_11
default: web_search
@ -5547,10 +5649,15 @@ components:
type: string
const: output_text
default: output_text
annotations:
type: array
items:
$ref: '#/components/schemas/OpenAIResponseAnnotations'
additionalProperties: false
required:
- text
- type
- annotations
title: >-
OpenAIResponseOutputMessageContentOutputText
"OpenAIResponseOutputMessageFileSearchToolCall":
@ -7867,6 +7974,107 @@ components:
- vector_db_id
- chunk_size_in_tokens
title: InsertRequest
Chunk:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The content of the chunk, which can be interleaved text, images, or other
types.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
Metadata associated with the chunk that will be used in the model context
during inference.
embedding:
type: array
items:
type: number
description: >-
Optional embedding for the chunk. If not provided, it will be computed
later.
stored_chunk_id:
type: string
description: >-
The chunk ID that is stored in the vector database. Used for backend functionality.
chunk_metadata:
$ref: '#/components/schemas/ChunkMetadata'
description: >-
Metadata for the chunk that will NOT be used in the context during inference.
The `chunk_metadata` is required for backend functionality.
additionalProperties: false
required:
- content
- metadata
title: Chunk
description: >-
A chunk of content that can be inserted into a vector database.
ChunkMetadata:
type: object
properties:
chunk_id:
type: string
description: >-
The ID of the chunk. If not set, it will be generated based on the document
ID and content.
document_id:
type: string
description: >-
The ID of the document this chunk belongs to.
source:
type: string
description: >-
The source of the content, such as a URL, file path, or other identifier.
created_timestamp:
type: integer
description: >-
An optional timestamp indicating when the chunk was created.
updated_timestamp:
type: integer
description: >-
An optional timestamp indicating when the chunk was last updated.
chunk_window:
type: string
description: >-
The window of the chunk, which can be used to group related chunks together.
chunk_tokenizer:
type: string
description: >-
The tokenizer used to create the chunk. Default is Tiktoken.
chunk_embedding_model:
type: string
description: >-
The embedding model used to create the chunk's embedding.
chunk_embedding_dimension:
type: integer
description: >-
The dimension of the embedding vector for the chunk.
content_token_count:
type: integer
description: >-
The number of tokens in the content of the chunk.
metadata_token_count:
type: integer
description: >-
The number of tokens in the metadata of the chunk.
additionalProperties: false
title: ChunkMetadata
description: >-
`ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional
information about the chunk that will not be used in the context during
inference, but is required for backend functionality. The `ChunkMetadata` is
set during chunk creation in `MemoryToolRuntimeImpl().insert()` and is not
expected to change afterwards. Use `Chunk.metadata` for metadata that will
be used in the context during inference.
InsertChunksRequest:
type: object
properties:
@ -7877,40 +8085,7 @@ components:
chunks:
type: array
items:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The content of the chunk, which can be interleaved text, images,
or other types.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
Metadata associated with the chunk, such as document ID, source,
or other relevant information.
embedding:
type: array
items:
type: number
description: >-
Optional embedding for the chunk. If not provided, it will be computed
later.
additionalProperties: false
required:
- content
- metadata
title: Chunk
description: >-
A chunk of content that can be inserted into a vector database.
$ref: '#/components/schemas/Chunk'
description: >-
The chunks to insert. Each `Chunk` should contain content which can be
interleaved text, images, or other types. `metadata`: `dict[str, Any]`
@ -10231,40 +10406,7 @@ components:
chunks:
type: array
items:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The content of the chunk, which can be interleaved text, images,
or other types.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
Metadata associated with the chunk, such as document ID, source,
or other relevant information.
embedding:
type: array
items:
type: number
description: >-
Optional embedding for the chunk. If not provided, it will be computed
later.
additionalProperties: false
required:
- content
- metadata
title: Chunk
description: >-
A chunk of content that can be inserted into a vector database.
$ref: '#/components/schemas/Chunk'
scores:
type: array
items:

View file

@ -18,7 +18,7 @@ providers:
- provider_id: ollama
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:http://localhost:11434}
url: ${env.OLLAMA_URL:=http://localhost:11434}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
@ -26,7 +26,7 @@ providers:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -38,7 +38,7 @@ providers:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
@ -46,7 +46,7 @@ providers:
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
@ -85,7 +85,7 @@ providers:
# config is a dictionary that contains the configuration for the provider.
# in this case, the configuration is the url of the ollama server
config:
url: ${env.OLLAMA_URL:http://localhost:11434}
url: ${env.OLLAMA_URL:=http://localhost:11434}
```
A few things to note:
- A _provider instance_ is identified with an (id, type, configuration) triplet.
@ -94,6 +94,95 @@ A few things to note:
- The configuration dictionary is provider-specific.
- Notice that configuration can reference environment variables (with default values), which are expanded at runtime. When you run a stack server (via docker or via `llama stack run`), you can specify `--env OLLAMA_URL=http://my-server:11434` to override the default value.
### Environment Variable Substitution
Llama Stack supports environment variable substitution in configuration values using the
`${env.VARIABLE_NAME}` syntax. This allows you to externalize configuration values and provide
different settings for different environments. The syntax is inspired by [bash parameter expansion](https://www.gnu.org/software/bash/manual/html_node/Shell-Parameter-Expansion.html)
and follows similar patterns.
#### Basic Syntax
The basic syntax for environment variable substitution is:
```yaml
config:
api_key: ${env.API_KEY}
url: ${env.SERVICE_URL}
```
If the environment variable is not set, the server will raise an error during startup.
#### Default Values
You can provide default values using the `:=` operator:
```yaml
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
port: ${env.PORT:=8321}
timeout: ${env.TIMEOUT:=60}
```
If an environment variable is not set, the value after `:=` is used as the default; for example, `OLLAMA_URL` above falls back to `http://localhost:11434`.
Empty defaults are not allowed, so `url: ${env.OLLAMA_URL:=}` will raise an error if the environment variable is not set.
#### Conditional Values
You can use the `:+` operator to provide a value only when the environment variable is set:
```yaml
config:
# Only include this field if ENVIRONMENT is set
environment: ${env.ENVIRONMENT:+production}
```
If the environment variable is set, the value after `:+` will be used. If it is not set, the field
will be omitted with a `None` value.
An empty conditional value such as `${env.ENVIRONMENT:+}` is also supported; it means the field is
omitted when the environment variable is not set, which is useful for making a field optional and enabling it at runtime when desired.
#### Examples
Here are some common patterns:
```yaml
# Required environment variable (will error if not set)
api_key: ${env.OPENAI_API_KEY}
# Optional with default
base_url: ${env.API_BASE_URL:=https://api.openai.com/v1}
# Conditional field
debug_mode: ${env.DEBUG:+true}
# Optional field that becomes None if not set
optional_token: ${env.OPTIONAL_TOKEN:+}
```
#### Runtime Override
You can override environment variables at runtime when starting the server:
```bash
# Override specific environment variables
llama stack run --config run.yaml --env API_KEY=sk-123 --env BASE_URL=https://custom-api.com
# Or set them in your shell
export API_KEY=sk-123
export BASE_URL=https://custom-api.com
llama stack run --config run.yaml
```
#### Type Safety
The environment variable substitution system is type-safe:
- String values remain strings
- Empty conditional values (`${env.VAR:+}`) are converted to `None` for fields that accept `str | None`
- Numeric defaults are properly typed (e.g., `${env.PORT:=8321}` becomes an integer)
- Boolean defaults work correctly (e.g., `${env.DEBUG:=false}` becomes a boolean)
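The rules above can be summarized with a short, self-contained sketch. It is illustrative only: the real logic lives in `replace_env_vars()` (shown later in this diff), and the helper names `resolve` and `_coerce` are invented for the example.
```python
import os
import re

# Same pattern as replace_env_vars(): ${env.VAR}, ${env.VAR:=default}, ${env.VAR:+value_if_set}
_PATTERN = re.compile(r"\${env\.([A-Z0-9_]+)(?::([=+])([^}]*))?}")

def _coerce(value: str):
    # Empty strings become None; "true"/"false" become booleans; numeric strings become int/float.
    if value == "":
        return None
    if value.lower() in ("true", "false"):
        return value.lower() == "true"
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            continue
    return value

def resolve(config_value: str):
    def repl(match: re.Match) -> str:
        name, operator, expr = match.group(1), match.group(2), match.group(3)
        current = os.environ.get(name)
        if operator == "=":              # ${env.FOO:=default}
            if current:
                return current
            if expr == "":               # empty defaults are rejected
                raise ValueError(f"{name} is not set and has no default")
            return expr
        if operator == "+":              # ${env.FOO:+value_if_set}
            return expr if current else ""
        if not current:                  # bare ${env.FOO} with no operator
            raise ValueError(f"{name} is not set")
        return current

    return _coerce(_PATTERN.sub(repl, config_value))

os.environ.pop("PORT", None)
os.environ.pop("DEBUG", None)
print(resolve("${env.PORT:=8321}"))   # -> 8321 (int)
print(resolve("${env.DEBUG:+true}"))  # -> None, because DEBUG is not set
```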
## Resources
Finally, let's look at the `models` section:
@ -109,6 +198,18 @@ A Model is an instance of a "Resource" (see [Concepts](../concepts/index)) and i
What's with the `provider_model_id` field? This is an identifier for the model inside the provider's model catalog. Contrast it with `model_id` which is the identifier for the same model for Llama Stack's purposes. For example, you may want to name "llama3.2:vision-11b" as "image_captioning_model" when you use it in your Stack interactions. When omitted, the server will set `provider_model_id` to be the same as `model_id`.
If you need to register a model conditionally, for example only when specific environment variables are set, you can use the special `__disabled__` string as the default value of an environment variable substitution, as shown below:
```yaml
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL:=__disabled__}
provider_id: ollama
provider_model_id: ${env.INFERENCE_MODEL:=__disabled__}
```
The snippet above registers this model only if the environment variable `INFERENCE_MODEL` is set and non-empty. If the environment variable is not set, the model will not be registered at all.
## Server Configuration
The `server` section configures the HTTP server that serves the Llama Stack APIs:
@ -140,7 +241,7 @@ server:
config:
jwks:
uri: "https://kubernetes.default.svc:8443/openid/v1/jwks"
token: "${env.TOKEN:}"
token: "${env.TOKEN:+}"
key_recheck_period: 3600
tls_cafile: "/path/to/ca.crt"
issuer: "https://kubernetes.default.svc"
@ -384,12 +485,12 @@ providers:
- provider_id: vllm-0
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:http://localhost:8000}
url: ${env.VLLM_URL:=http://localhost:8000}
# this vLLM server serves the llama-guard model (e.g., llama-guard:3b)
- provider_id: vllm-1
provider_type: remote::vllm
config:
url: ${env.SAFETY_VLLM_URL:http://localhost:8001}
url: ${env.SAFETY_VLLM_URL:=http://localhost:8001}
...
models:
- metadata: {}

View file

@ -15,10 +15,10 @@ data:
- provider_id: vllm-inference
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:4096}
api_token: ${env.VLLM_API_TOKEN:fake}
tls_verify: ${env.VLLM_TLS_VERIFY:true}
url: ${env.VLLM_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: vllm-safety
provider_type: remote::vllm
config:
@ -30,10 +30,10 @@ data:
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:}
url: ${env.CHROMADB_URL:+}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -45,34 +45,34 @@ data:
config:
persistence_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
responses_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:}
service_name: ${env.OTEL_SERVICE_NAME:+}
sinks: ${env.TELEMETRY_SINKS:console}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
api_key: ${env.BRAVE_SEARCH_API_KEY:+}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
api_key: ${env.TAVILY_SEARCH_API_KEY:+}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
@ -82,19 +82,19 @@ data:
config: {}
metadata_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: llamastack_kvstore
inference_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
models:
- metadata:
embedding_dimension: 384
@ -106,11 +106,11 @@ data:
provider_id: vllm-inference
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL:meta-llama/Llama-Guard-3-1B}
model_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
provider_id: vllm-safety
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:meta-llama/Llama-Guard-3-1B}
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: []
datasets: []
scoring_fns: []

View file

@ -12,25 +12,25 @@ providers:
- provider_id: vllm-inference
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:4096}
api_token: ${env.VLLM_API_TOKEN:fake}
tls_verify: ${env.VLLM_TLS_VERIFY:true}
url: ${env.VLLM_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: vllm-safety
provider_type: remote::vllm
config:
url: ${env.VLLM_SAFETY_URL:http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:4096}
api_token: ${env.VLLM_API_TOKEN:fake}
tls_verify: ${env.VLLM_TLS_VERIFY:true}
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:}
url: ${env.CHROMADB_URL:+}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -42,34 +42,34 @@ providers:
config:
persistence_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
responses_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console}
service_name: ${env.OTEL_SERVICE_NAME:+console}
sinks: ${env.TELEMETRY_SINKS:+console}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
api_key: ${env.BRAVE_SEARCH_API_KEY:+}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
api_key: ${env.TAVILY_SEARCH_API_KEY:+}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
@ -79,19 +79,19 @@ providers:
config: {}
metadata_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: llamastack_kvstore
inference_store:
type: postgres
host: ${env.POSTGRES_HOST:localhost}
port: ${env.POSTGRES_PORT:5432}
db: ${env.POSTGRES_DB:llamastack}
user: ${env.POSTGRES_USER:llamastack}
password: ${env.POSTGRES_PASSWORD:llamastack}
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
models:
- metadata:
embedding_dimension: 384
@ -103,11 +103,11 @@ models:
provider_id: vllm-inference
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL:meta-llama/Llama-Guard-3-1B}
model_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
provider_id: vllm-safety
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:meta-llama/Llama-Guard-3-1B}
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: []
datasets: []
scoring_fns: []

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .agents import * # noqa: F401 F403
from .agents import *

View file

@ -44,10 +44,55 @@ OpenAIResponseInputMessageContent = Annotated[
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
@json_schema_type
class OpenAIResponseAnnotationFileCitation(BaseModel):
type: Literal["file_citation"] = "file_citation"
file_id: str
filename: str
index: int
@json_schema_type
class OpenAIResponseAnnotationCitation(BaseModel):
type: Literal["url_citation"] = "url_citation"
end_index: int
start_index: int
title: str
url: str
@json_schema_type
class OpenAIResponseAnnotationContainerFileCitation(BaseModel):
type: Literal["container_file_citation"] = "container_file_citation"
container_id: str
end_index: int
file_id: str
filename: str
start_index: int
@json_schema_type
class OpenAIResponseAnnotationFilePath(BaseModel):
type: Literal["file_path"] = "file_path"
file_id: str
index: int
OpenAIResponseAnnotations = Annotated[
OpenAIResponseAnnotationFileCitation
| OpenAIResponseAnnotationCitation
| OpenAIResponseAnnotationContainerFileCitation
| OpenAIResponseAnnotationFilePath,
Field(discriminator="type"),
]
register_schema(OpenAIResponseAnnotations, name="OpenAIResponseAnnotations")
@json_schema_type
class OpenAIResponseOutputMessageContentOutputText(BaseModel):
text: str
type: Literal["output_text"] = "output_text"
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
OpenAIResponseOutputMessageContent = Annotated[
@ -384,9 +429,16 @@ OpenAIResponseInput = Annotated[
register_schema(OpenAIResponseInput, name="OpenAIResponseInput")
# Must match type Literals of OpenAIResponseInputToolWebSearch below
WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11"]
@json_schema_type
class OpenAIResponseInputToolWebSearch(BaseModel):
type: Literal["web_search"] | Literal["web_search_preview_2025_03_11"] = "web_search"
# Must match values of WebSearchToolTypes above
type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = (
"web_search"
)
# TODO: actually use search_context_size somewhere...
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
# TODO: add user_location
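A brief construction sketch tying the two additions in this file together. It is illustrative only and assumes the models above are importable from `llama_stack.apis.agents.openai_responses`; the URL and text are made up.
```python
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseAnnotationCitation,
    OpenAIResponseInputToolWebSearch,
    OpenAIResponseOutputMessageContentOutputText,
    WebSearchToolTypes,
)

# The new "web_search_preview" alias is accepted alongside the existing type literals.
tool = OpenAIResponseInputToolWebSearch(type="web_search_preview")
assert tool.type in WebSearchToolTypes

# output_text content parts now carry a (possibly empty) list of annotations.
content = OpenAIResponseOutputMessageContentOutputText(
    text="Release notes are available online.",
    annotations=[
        OpenAIResponseAnnotationCitation(
            url="https://example.com/release-notes",
            title="Release notes",
            start_index=0,
            end_index=35,
        )
    ],
)
print(content.model_dump_json())  # the discriminator serializes type="url_citation"
```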

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .batch_inference import * # noqa: F401 F403
from .batch_inference import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .benchmarks import * # noqa: F401 F403
from .benchmarks import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datasetio import * # noqa: F401 F403
from .datasetio import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datasets import * # noqa: F401 F403
from .datasets import *

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
@ -13,7 +13,7 @@ from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class DatasetPurpose(str, Enum):
class DatasetPurpose(StrEnum):
"""
Purpose of the dataset. Each purpose has a required input data schema.

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .eval import * # noqa: F401 F403
from .eval import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .files import * # noqa: F401 F403
from .files import *

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import StrEnum
from typing import Annotated, Literal, Protocol, runtime_checkable
from fastapi import File, Form, Response, UploadFile
@ -16,7 +16,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
# OpenAI Files API Models
class OpenAIFilePurpose(str, Enum):
class OpenAIFilePurpose(StrEnum):
"""
Valid purpose values for OpenAI Files API.
"""

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inference import * # noqa: F401 F403
from .inference import *

View file

@ -20,7 +20,7 @@ from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.common.responses import Order
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.apis.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inspect import * # noqa: F401 F403
from .inspect import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .models import * # noqa: F401 F403
from .models import *

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import StrEnum
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
@ -22,7 +22,7 @@ class CommonModelFields(BaseModel):
@json_schema_type
class ModelType(str, Enum):
class ModelType(StrEnum):
llm = "llm"
embedding = "embedding"

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .post_training import * # noqa: F401 F403
from .post_training import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .providers import * # noqa: F401 F403
from .providers import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .safety import * # noqa: F401 F403
from .safety import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring import * # noqa: F401 F403
from .scoring import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring_functions import * # noqa: F401 F403
from .scoring_functions import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .shields import * # noqa: F401 F403
from .shields import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .synthetic_data_generation import * # noqa: F401 F403
from .synthetic_data_generation import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .telemetry import * # noqa: F401 F403
from .telemetry import *

View file

@ -4,5 +4,5 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .rag_tool import * # noqa: F401 F403
from .tools import * # noqa: F401 F403
from .rag_tool import *
from .tools import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .vector_dbs import * # noqa: F401 F403
from .vector_dbs import *

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .vector_io import * # noqa: F401 F403
from .vector_io import *

View file

@ -8,6 +8,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
@ -15,21 +16,80 @@ from pydantic import BaseModel, Field
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.strong_typing.schema import register_schema
@json_schema_type
class ChunkMetadata(BaseModel):
"""
`ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk that
will not be used in the context during inference, but is required for backend functionality. The `ChunkMetadata`
is set during chunk creation in `MemoryToolRuntimeImpl().insert()` and is not expected to change afterwards.
Use `Chunk.metadata` for metadata that will be used in the context during inference.
:param chunk_id: The ID of the chunk. If not set, it will be generated based on the document ID and content.
:param document_id: The ID of the document this chunk belongs to.
:param source: The source of the content, such as a URL, file path, or other identifier.
:param created_timestamp: An optional timestamp indicating when the chunk was created.
:param updated_timestamp: An optional timestamp indicating when the chunk was last updated.
:param chunk_window: The window of the chunk, which can be used to group related chunks together.
:param chunk_tokenizer: The tokenizer used to create the chunk. Default is Tiktoken.
:param chunk_embedding_model: The embedding model used to create the chunk's embedding.
:param chunk_embedding_dimension: The dimension of the embedding vector for the chunk.
:param content_token_count: The number of tokens in the content of the chunk.
:param metadata_token_count: The number of tokens in the metadata of the chunk.
"""
chunk_id: str | None = None
document_id: str | None = None
source: str | None = None
created_timestamp: int | None = None
updated_timestamp: int | None = None
chunk_window: str | None = None
chunk_tokenizer: str | None = None
chunk_embedding_model: str | None = None
chunk_embedding_dimension: int | None = None
content_token_count: int | None = None
metadata_token_count: int | None = None
@json_schema_type
class Chunk(BaseModel):
"""
A chunk of content that can be inserted into a vector database.
:param content: The content of the chunk, which can be interleaved text, images, or other types.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
:param metadata: Metadata associated with the chunk that will be used in the model context during inference.
:param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality.
:param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference.
The `chunk_metadata` is required for backend functionality.
"""
content: InterleavedContent
metadata: dict[str, Any] = Field(default_factory=dict)
embedding: list[float] | None = None
# The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id"
stored_chunk_id: str | None = Field(default=None, alias="chunk_id")
chunk_metadata: ChunkMetadata | None = None
model_config = {"populate_by_name": True}
def model_post_init(self, __context):
# Extract chunk_id from metadata if present
if self.metadata and "chunk_id" in self.metadata:
self.stored_chunk_id = self.metadata.pop("chunk_id")
@property
def chunk_id(self) -> str:
"""Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set."""
if self.stored_chunk_id:
return self.stored_chunk_id
if "document_id" in self.metadata:
return generate_chunk_id(self.metadata["document_id"], str(self.content))
return generate_chunk_id(str(uuid.uuid4()), str(self.content))
@json_schema_type
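A brief usage sketch of the new chunk ID behavior defined above. It is illustrative only; it assumes `Chunk` is importable from `llama_stack.apis.vector_io`, and the contents and IDs are made up.
```python
from llama_stack.apis.vector_io import Chunk

# A chunk_id supplied via metadata is moved into stored_chunk_id by model_post_init.
stored = Chunk(
    content="second paragraph",
    metadata={"chunk_id": "abc-001", "document_id": "doc-1"},
)
assert stored.chunk_id == "abc-001"
assert "chunk_id" not in stored.metadata

# Without an explicit ID, chunk_id is derived deterministically from document_id + content.
derived = Chunk(content="first paragraph", metadata={"document_id": "doc-1"})
print(derived.chunk_id)  # generated by generate_chunk_id()
```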

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import StrEnum
from typing import Self
from pydantic import BaseModel, model_validator
@ -12,7 +12,7 @@ from pydantic import BaseModel, model_validator
from .conditions import parse_conditions
class Action(str, Enum):
class Action(StrEnum):
CREATE = "create"
READ = "read"
UPDATE = "update"

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any
@ -29,8 +29,8 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
LLAMA_STACK_RUN_CONFIG_VERSION = 2
RoutingKey = str | list[str]
@ -159,7 +159,7 @@ class LoggingConfig(BaseModel):
)
class AuthProviderType(str, Enum):
class AuthProviderType(StrEnum):
"""Supported authentication provider types."""
OAUTH2_TOKEN = "oauth2_token"
@ -182,7 +182,7 @@ class AuthenticationRequiredError(Exception):
pass
class QuotaPeriod(str, Enum):
class QuotaPeriod(StrEnum):
DAY = "day"
@ -229,7 +229,7 @@ class ServerConfig(BaseModel):
class StackRunConfig(BaseModel):
version: str = LLAMA_STACK_RUN_CONFIG_VERSION
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
image_name: str = Field(
...,
@ -300,7 +300,7 @@ a default SQLite store will be used.""",
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
version: int = LLAMA_STACK_BUILD_CONFIG_VERSION
distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
image_type: str = Field(

View file

@ -30,7 +30,13 @@ from llama_stack.apis.inference import (
ListOpenAIChatCompletionResponse,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAICompletionWithInputMessages,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
Order,
ResponseFormat,
SamplingParams,
@ -41,14 +47,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.log import get_logger

View file

@ -16,17 +16,15 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

View file

@ -98,6 +98,15 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
method = getattr(impls[api], register_method)
for obj in objects:
# In complex templates, like our starter template, we may have dynamic model ids
# given by environment variables. This allows those environment variables to have
# a default value of __disabled__ to skip registration of the model if not set.
if (
hasattr(obj, "provider_model_id")
and obj.provider_model_id is not None
and "__disabled__" in obj.provider_model_id
):
continue
# we want to maintain the type information in arguments to method.
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
@ -118,7 +127,12 @@ class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name
self.path = path
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
super().__init__(
f"Environment variable '{var_name}' not set or empty {f'at {path}' if path else ''}. "
f"Use ${{env.{var_name}:=default_value}} to provide a default value, "
f"${{env.{var_name}:+value_if_set}} to make the field conditional, "
f"or ensure the environment variable is set."
)
def replace_env_vars(config: Any, path: str = "") -> Any:
@ -141,25 +155,27 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
return result
elif isinstance(config, str):
# Updated pattern to support both default values (:) and conditional values (+)
pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"
# Pattern supports bash-like syntax: := for default and :+ for conditional and a optional value
pattern = r"\${env\.([A-Z0-9_]+)(?::([=+])([^}]*))?}"
def get_env_var(match):
def get_env_var(match: re.Match):
env_var = match.group(1)
operator = match.group(2) # ':' for default, '+' for conditional
operator = match.group(2) # '=' for default, '+' for conditional
value_expr = match.group(3)
env_value = os.environ.get(env_var)
if operator == ":": # Default value syntax: ${env.FOO:default}
if operator == "=": # Default value syntax: ${env.FOO:=default}
if not env_value:
if value_expr is None:
# value_expr returns empty string (not None) when not matched
# This means ${env.FOO:=} is an error
if value_expr == "":
raise EnvVarError(env_var, path)
else:
value = value_expr
else:
value = env_value
elif operator == "+": # Conditional value syntax: ${env.FOO+value_if_set}
elif operator == "+": # Conditional value syntax: ${env.FOO:+value_if_set}
if env_value:
value = value_expr
else:
@ -174,13 +190,42 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
return os.path.expanduser(value)
try:
return re.sub(pattern, get_env_var, config)
result = re.sub(pattern, get_env_var, config)
return _convert_string_to_proper_type(result)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return config
def _convert_string_to_proper_type(value: str) -> Any:
# This might be tricky depending on what the config type is, if 'str | None' we are
# good, if 'str' we need to keep the empty string... 'str | None' is more common and
# providers config should be typed this way.
# TODO: we could try to load the config class and see if the config has a field with type 'str | None'
# and then convert the empty string to None or not
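# Examples: "8321" -> 8321, "0.3" -> 0.3, "false" -> False, "true" -> True, "" -> None;
# anything else is returned unchanged as a string.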
if value == "":
return None
lowered = value.lower()
if lowered == "true":
return True
elif lowered == "false":
return False
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
def validate_env_pair(env_pair: str) -> tuple[str, str]:
"""Validate and split an environment variable key-value pair."""
try:

View file

@ -25,7 +25,7 @@ class LlamaStackApi:
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
scoring_params = dict.fromkeys(scoring_function_ids)
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)

View file

@ -33,7 +33,7 @@ CATEGORIES = [
]
# Initialize category levels with default level
_category_levels: dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
_category_levels: dict[str, int] = dict.fromkeys(CATEGORIES, DEFAULT_LOG_LEVEL)
def config_to_category_levels(category: str, level: str):

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import base64
from enum import Enum
from enum import Enum, StrEnum
from io import BytesIO
from typing import Annotated, Any, Literal
@ -171,7 +171,7 @@ class GenerationResult(BaseModel):
ignore_token: bool
class QuantizationMode(str, Enum):
class QuantizationMode(StrEnum):
none = "none"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import StrEnum
from typing import Any, Protocol
from urllib.parse import urlparse
@ -225,7 +225,7 @@ def remote_provider_spec(
)
class HealthStatus(str, Enum):
class HealthStatus(StrEnum):
OK = "OK"
ERROR = "Error"
NOT_IMPLEMENTED = "Not Implemented"

View file

@ -42,9 +42,10 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
WebSearchToolTypes,
)
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference.inference import (
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
@ -583,7 +584,7 @@ class OpenAIResponsesImpl:
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
)
from llama_stack.apis.tools.tools import Tool
from llama_stack.apis.tools import Tool
mcp_tool_to_server = {}
@ -609,7 +610,7 @@ class OpenAIResponsesImpl:
# TODO: Handle other tool types
if input_tool.type == "function":
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type == "web_search":
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
if not tool:

View file

@ -208,7 +208,7 @@ class MetaReferenceEvalImpl(
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
scoring_functions_dict = dict.fromkeys(scoring_functions)
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions_dict

View file

@ -23,7 +23,7 @@ class LocalfsFilesImplConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"storage_dir": "${env.FILES_STORAGE_DIR:" + __distro_dir__ + "/files}",
"storage_dir": "${env.FILES_STORAGE_DIR:=" + __distro_dir__ + "/files}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="files_metadata.db",

View file

@ -49,11 +49,11 @@ class MetaReferenceInferenceConfig(BaseModel):
def sample_run_config(
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:=null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:=bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:=0}",
max_batch_size: str = "${env.MAX_BATCH_SIZE:=1}",
max_seq_len: str = "${env.MAX_SEQ_LEN:=4096}",
**kwargs,
) -> dict[str, Any]:
return {

View file

@ -44,10 +44,10 @@ class VLLMConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.MAX_TOKENS:4096}",
"max_model_len": "${env.MAX_MODEL_LEN:4096}",
"max_num_seqs": "${env.MAX_NUM_SEQS:4}",
"enforce_eager": "${env.ENFORCE_EAGER:False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:=1}",
"max_tokens": "${env.MAX_TOKENS:=4096}",
"max_model_len": "${env.MAX_MODEL_LEN:=4096}",
"max_num_seqs": "${env.MAX_NUM_SEQS:=4}",
"enforce_eager": "${env.ENFORCE_EAGER:=False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:=0.3}",
}

View file

@ -17,5 +17,5 @@ class BraintrustScoringConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"openai_api_key": "${env.OPENAI_API_KEY:}",
"openai_api_key": "${env.OPENAI_API_KEY:+}",
}

View file

@ -7,7 +7,7 @@ from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,

View file

@ -6,7 +6,7 @@
import re
from typing import Any
from llama_stack.apis.inference.inference import Inference, UserMessage
from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field, field_validator
@ -12,7 +12,7 @@ from pydantic import BaseModel, Field, field_validator
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(str, Enum):
class TelemetrySink(StrEnum):
OTEL_TRACE = "otel_trace"
OTEL_METRIC = "otel_metric"
SQLITE = "sqlite"
@ -20,12 +20,12 @@ class TelemetrySink(str, Enum):
class TelemetryConfig(BaseModel):
otel_trace_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
otel_trace_endpoint: str | None = Field(
default=None,
description="The OpenTelemetry collector endpoint URL for traces",
)
otel_metric_endpoint: str = Field(
default="http://localhost:4318/v1/metrics",
otel_metric_endpoint: str | None = Field(
default=None,
description="The OpenTelemetry collector endpoint URL for metrics",
)
service_name: str = Field(
@ -52,7 +52,7 @@ class TelemetryConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:\u200b}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
"service_name": "${env.OTEL_SERVICE_NAME:=\u200b}",
"sinks": "${env.TELEMETRY_SINKS:=console,sqlite}",
"sqlite_db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
}
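
TelemetrySink now derives from StrEnum (Python 3.11+), so members are themselves str instances and compare equal to their values without .value. A small check reproducing only the members visible in this hunk:

from enum import StrEnum

class TelemetrySink(StrEnum):
    OTEL_TRACE = "otel_trace"
    OTEL_METRIC = "otel_metric"
    SQLITE = "sqlite"

assert TelemetrySink.SQLITE == "sqlite"
assert isinstance(TelemetrySink.OTEL_TRACE, str)
assert ",".join(TelemetrySink) == "otel_trace,otel_metric,sqlite"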

View file

@ -87,12 +87,16 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider
if TelemetrySink.OTEL_TRACE in self.config.sinks:
if self.config.otel_trace_endpoint is None:
raise ValueError("otel_trace_endpoint is required when OTEL_TRACE is enabled")
span_exporter = OTLPSpanExporter(
endpoint=self.config.otel_trace_endpoint,
)
span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
if TelemetrySink.OTEL_METRIC in self.config.sinks:
if self.config.otel_metric_endpoint is None:
raise ValueError("otel_metric_endpoint is required when OTEL_METRIC is enabled")
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_metric_endpoint,

View file

@ -81,6 +81,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
chunks = []
for doc in documents:
content = await content_from_doc(doc)
# TODO: we should add enrichment here as URLs won't be added to the metadata by default
chunks.extend(
make_overlapped_chunks(
doc.document_id,
@ -157,8 +158,24 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
)
break
metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
# Build the metadata passed into the context: include useful keys from chunk_metadata and drop internal bookkeeping keys from metadata
chunk_metadata_keys_to_include_from_context = [
"chunk_id",
"document_id",
"source",
]
metadata_keys_to_exclude_from_context = [
"token_count",
"metadata_token_count",
]
metadata_for_context = {}
for k in chunk_metadata_keys_to_include_from_context:
metadata_for_context[k] = getattr(chunk.chunk_metadata, k)
for k in metadata:
if k not in metadata_keys_to_exclude_from_context:
metadata_for_context[k] = metadata[k]
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_for_context)
picked.append(TextContentItem(text=text_content))
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
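
For context, the rendered knowledge_search result above interpolates the trimmed metadata into query_config.chunk_template. A toy illustration of that formatting step (the template string and metadata values here are hypothetical):

chunk_template = "Result {index}\nContent: {chunk}\nMetadata: {metadata}\n"  # hypothetical template
metadata_for_context = {"chunk_id": "c-1", "document_id": "doc-1", "source": "kb://faq", "author": "alice"}
print(chunk_template.format(index=1, chunk="retrieved text...", metadata=metadata_for_context))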

View file

@ -16,8 +16,7 @@ import numpy as np
from numpy.typing import NDArray
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,

View file

@ -19,5 +19,5 @@ class QdrantVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
"path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
}

View file

@ -15,5 +15,5 @@ class SQLiteVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
}

View file

@ -5,20 +5,18 @@
# the root directory of this source tree.
import asyncio
import hashlib
import json
import logging
import sqlite3
import struct
import uuid
from typing import Any
import numpy as np
import sqlite_vec
from numpy.typing import NDArray
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -66,7 +64,7 @@ def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
score_range = max_score - min_score
if score_range > 0:
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
return {doc_id: 1.0 for doc_id in scores}
return dict.fromkeys(scores, 1.0)
def _weighted_rerank(
@ -201,10 +199,7 @@ class SQLiteVecIndex(EmbeddingIndex):
batch_embeddings = embeddings[i : i + batch_size]
# Insert metadata
metadata_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
for chunk in batch_chunks
]
metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
cur.executemany(
f"""
INSERT INTO {self.metadata_table} (id, chunk)
@ -218,7 +213,7 @@ class SQLiteVecIndex(EmbeddingIndex):
embedding_data = [
(
(
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
chunk.chunk_id,
serialize_vector(emb.tolist()),
)
)
@ -230,10 +225,7 @@ class SQLiteVecIndex(EmbeddingIndex):
)
# Insert FTS content
fts_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content)
for chunk in batch_chunks
]
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
cur.executemany(
f"DELETE FROM {self.fts_table} WHERE id = ?;",
@ -381,13 +373,12 @@ class SQLiteVecIndex(EmbeddingIndex):
vector_response = await self.query_vector(embedding, k, score_threshold)
keyword_response = await self.query_keyword(query_string, k, score_threshold)
# Convert responses to score dictionaries using generate_chunk_id
# Convert responses to score dictionaries using chunk_id
vector_scores = {
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
chunk.chunk_id: score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
@ -408,13 +399,7 @@ class SQLiteVecIndex(EmbeddingIndex):
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create a map of chunk_id to chunk for both responses
chunk_map = {}
for c in vector_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
for c in keyword_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
# Use the map to look up chunks by their IDs
chunks = []
@ -757,9 +742,3 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params)
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
hash_input = f"{document_id}:{chunk_text}".encode()
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
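
The helper deleted above is reproduced below as a standalone reference, since this adapter now reads ids from chunk.chunk_id instead of recomputing them; a shared generate_chunk_id is imported later in this diff from llama_stack.providers.utils.vector_io.chunk_utils, though its exact signature is not shown there.

import hashlib
import uuid

def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    """Derive a deterministic UUID string from the document id and the chunk text."""
    hash_input = f"{document_id}:{chunk_text}".encode()
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))

# generate_chunk_id("doc-1", "hello world") always yields the same UUID string for the same inputs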

View file

@ -70,7 +70,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="ollama",
pip_packages=["ollama", "aiohttp"],
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.remote.inference.ollama",
),

View file

@ -67,7 +67,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.safety,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=["litellm"],
pip_packages=["litellm", "requests"],
module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",

View file

@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::basic",
pip_packages=[],
pip_packages=["requests"],
module="llama_stack.providers.inline.scoring.basic",
config_class="llama_stack.providers.inline.scoring.basic.BasicScoringConfig",
api_dependencies=[

View file

@ -54,8 +54,8 @@ class NvidiaDatasetIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"api_key": "${env.NVIDIA_API_KEY:}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
"project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
"datasets_url": "${env.NVIDIA_DATASETS_URL:http://nemo.test}",
"api_key": "${env.NVIDIA_API_KEY:+}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:=default}",
"project_id": "${env.NVIDIA_PROJECT_ID:=test-project}",
"datasets_url": "${env.NVIDIA_DATASETS_URL:=http://nemo.test}",
}

View file

@ -66,7 +66,7 @@ class NvidiaDatasetIOAdapter:
Returns:
Dataset
"""
## add warnings for unsupported params
# add warnings for unsupported params
request_body = {
"name": dataset_def.identifier,
"namespace": self.config.dataset_namespace,

View file

@ -25,5 +25,5 @@ class NVIDIAEvalConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}",
}

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)

View file

@ -24,6 +24,12 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@ -33,14 +39,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)

View file

@ -9,7 +9,7 @@ from typing import Any
from openai import AsyncOpenAI
from llama_stack.apis.inference.inference import (
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChoiceDelta,

View file

@ -55,7 +55,7 @@ class NVIDIAConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
"api_key": "${env.NVIDIA_API_KEY:}",
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
"url": "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
"api_key": "${env.NVIDIA_API_KEY:+}",
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:=True}",
}

View file

@ -29,20 +29,18 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference import (

View file

@ -10,6 +10,6 @@ from .config import OllamaImplConfig
async def get_adapter_impl(config: OllamaImplConfig, _deps):
from .ollama import OllamaInferenceAdapter
impl = OllamaInferenceAdapter(config.url)
impl = OllamaInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -13,7 +13,13 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL
raise_on_connect_error: bool = True
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs) -> dict[str, Any]:
return {"url": url}
def sample_run_config(
cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", raise_on_connect_error: bool = True, **kwargs
) -> dict[str, Any]:
return {
"url": url,
"raise_on_connect_error": raise_on_connect_error,
}

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,

View file

@ -9,7 +9,6 @@ import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
import httpx
from ollama import AsyncClient # type: ignore[attr-defined]
from openai import AsyncOpenAI
@ -33,6 +32,13 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -41,15 +47,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
@ -57,6 +54,7 @@ from llama_stack.providers.datatypes import (
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -90,9 +88,10 @@ class OllamaInferenceAdapter(
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, url: str) -> None:
def __init__(self, config: OllamaImplConfig) -> None:
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.url = url
self.url = config.url
self.raise_on_connect_error = config.raise_on_connect_error
@property
def client(self) -> AsyncClient:
@ -103,8 +102,13 @@ class OllamaInferenceAdapter(
return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
await self.health()
logger.debug(f"checking connectivity to Ollama at `{self.url}`...")
health_response = await self.health()
if health_response["status"] == HealthStatus.ERROR:
if self.raise_on_connect_error:
raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
else:
logger.warning("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
async def health(self) -> HealthResponse:
"""
@ -117,10 +121,8 @@ class OllamaInferenceAdapter(
try:
await self.client.ps()
return HealthResponse(status=HealthStatus.OK)
except httpx.ConnectError as e:
raise RuntimeError(
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
) from e
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def shutdown(self) -> None:
pass
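
A usage sketch for the new config-based constructor and the softened health check, assuming the import paths implied by the relative imports above (the .ollama module path in particular is an assumption):

import asyncio

from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.remote.inference.ollama.ollama import OllamaInferenceAdapter  # path assumed

async def main() -> None:
    # With raise_on_connect_error=False, initialize() logs a warning instead of raising
    # RuntimeError when the Ollama server is unreachable.
    config = OllamaImplConfig(url="http://localhost:11434", raise_on_connect_error=False)
    adapter = OllamaInferenceAdapter(config)
    await adapter.initialize()
    print(await adapter.health())  # HealthResponse with status OK or ERROR

asyncio.run(main())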

View file

@ -6,7 +6,7 @@
from dataclasses import dataclass
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)

View file

@ -10,7 +10,7 @@ from typing import Any
from openai import AsyncOpenAI
from llama_stack.apis.inference.inference import (
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,

View file

@ -19,7 +19,12 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -28,13 +33,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

View file

@ -25,6 +25,6 @@ class RunpodImplConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"url": "${env.RUNPOD_URL:}",
"api_token": "${env.RUNPOD_API_TOKEN:}",
"url": "${env.RUNPOD_URL:+}",
"api_token": "${env.RUNPOD_API_TOKEN:+}",
}

View file

@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference.inference import OpenAIEmbeddingsResponse
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

View file

@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel):
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY:}",
"api_key": "${env.TOGETHER_API_KEY:+}",
}

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,

View file

@ -23,7 +23,12 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@ -33,13 +38,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

View file

@ -34,9 +34,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
@classmethod
def validate_tls_verify(cls, v):
if isinstance(v, str):
# Check if it's a boolean string
if v.lower() in ("true", "false"):
return v.lower() == "true"
# Otherwise, treat it as a cert path
cert_path = Path(v).expanduser().resolve()
if not cert_path.exists():
@ -54,7 +51,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
):
return {
"url": url,
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
"api_token": "${env.VLLM_API_TOKEN:fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:true}",
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
"api_token": "${env.VLLM_API_TOKEN:=fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
}
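
With the boolean-string branch removed, the validator above presumably treats any string strictly as a certificate path. A hedged standalone sketch of that narrowed behaviour (the return value and error message are assumptions, not the project's exact code):

from pathlib import Path

def validate_tls_verify(v: bool | str) -> bool | str:
    if isinstance(v, str):
        # Strings are no longer parsed as "true"/"false"; they must point at an existing cert file.
        cert_path = Path(v).expanduser().resolve()
        if not cert_path.exists():
            raise ValueError(f"TLS certificate file not found: {v}")
        return v
    return v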

View file

@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
import httpx
from openai import AsyncOpenAI
from openai import APIConnectionError, AsyncOpenAI
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
@ -38,9 +38,13 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -49,12 +53,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
@ -461,7 +459,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
res = await client.models.list()
try:
res = await client.models.list()
except APIConnectionError as e:
raise ValueError(
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
) from e
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(

View file

@ -40,7 +40,7 @@ class WatsonXConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:}",
"project_id": "${env.WATSONX_PROJECT_ID:}",
"url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:+}",
"project_id": "${env.WATSONX_PROJECT_ID:+}",
}

View file

@ -18,10 +18,16 @@ from llama_stack.apis.inference import (
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
GreedySamplingStrategy,
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -29,14 +35,6 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
GreedySamplingStrategy,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
TopKSamplingStrategy,
TopPSamplingStrategy,
)

View file

@ -55,10 +55,10 @@ class NvidiaPostTrainingConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"api_key": "${env.NVIDIA_API_KEY:}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
"project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
"customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",
"api_key": "${env.NVIDIA_API_KEY:+}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:=default}",
"project_id": "${env.NVIDIA_PROJECT_ID:=test-project}",
"customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}",
}

View file

@ -35,6 +35,6 @@ class NVIDIASafetyConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
"config_id": "${env.NVIDIA_GUARDRAILS_CONFIG_ID:self-check}",
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}",
"config_id": "${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}",
}

View file

@ -22,6 +22,6 @@ class BraveSearchToolConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"api_key": "${env.BRAVE_SEARCH_API_KEY:}",
"api_key": "${env.BRAVE_SEARCH_API_KEY:+}",
"max_results": 3,
}

View file

@ -22,6 +22,6 @@ class TavilySearchToolConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"api_key": "${env.TAVILY_SEARCH_API_KEY:}",
"api_key": "${env.TAVILY_SEARCH_API_KEY:+}",
"max_results": 3,
}

View file

@ -17,5 +17,5 @@ class WolframAlphaToolConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"api_key": "${env.WOLFRAM_ALPHA_API_KEY:}",
"api_key": "${env.WOLFRAM_ALPHA_API_KEY:+}",
}

View file

@ -22,8 +22,8 @@ class PGVectorVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(
cls,
host: str = "${env.PGVECTOR_HOST:localhost}",
port: int = "${env.PGVECTOR_PORT:5432}",
host: str = "${env.PGVECTOR_HOST:=localhost}",
port: int = "${env.PGVECTOR_PORT:=5432}",
db: str = "${env.PGVECTOR_DB}",
user: str = "${env.PGVECTOR_USER}",
password: str = "${env.PGVECTOR_PASSWORD}",

View file

@ -70,8 +70,8 @@ class QdrantIndex(EmbeddingIndex):
)
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
for _i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
chunk_id = chunk.chunk_id
points.append(
PointStruct(
id=convert_id(chunk_id),

View file

@ -23,6 +23,13 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -31,16 +38,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models.models import Model
from llama_stack.apis.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import (

View file

@ -95,27 +95,25 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
GreedySamplingStrategy,
Message,
SamplingParams,
SystemMessage,
TokenLogProbs,
ToolChoice,
ToolResponseMessage,
TopKSamplingStrategy,
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.apis.inference.inference import (
JsonSchemaResponseFormat,
Message,
OpenAIChatCompletion,
OpenAICompletion,
OpenAICompletionChoice,
OpenAIEmbeddingData,
OpenAIMessageParam,
OpenAIResponseFormatParam,
SamplingParams,
SystemMessage,
TokenLogProbs,
ToolChoice,
ToolConfig,
ToolResponseMessage,
TopKSamplingStrategy,
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.apis.inference.inference import (
from llama_stack.apis.inference import (
OpenAIChoice as OpenAIChatCompletionChoice,
)
from llama_stack.models.llama.datatypes import (
@ -1026,7 +1024,9 @@ def openai_messages_to_messages(
return converted_messages
def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam]):
def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam] | None):
if content is None:
return ""
if isinstance(content, str):
return content
elif isinstance(content, list):

View file

@ -45,8 +45,8 @@ class RedisKVStoreConfig(CommonConfig):
return {
"type": "redis",
"namespace": None,
"host": "${env.REDIS_HOST:localhost}",
"port": "${env.REDIS_PORT:6379}",
"host": "${env.REDIS_HOST:=localhost}",
"port": "${env.REDIS_PORT:=6379}",
}
@ -66,7 +66,7 @@ class SqliteKVStoreConfig(CommonConfig):
return {
"type": "sqlite",
"namespace": None,
"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
}
@ -84,12 +84,12 @@ class PostgresKVStoreConfig(CommonConfig):
return {
"type": "postgres",
"namespace": None,
"host": "${env.POSTGRES_HOST:localhost}",
"port": "${env.POSTGRES_PORT:5432}",
"db": "${env.POSTGRES_DB:llamastack}",
"user": "${env.POSTGRES_USER:llamastack}",
"password": "${env.POSTGRES_PASSWORD:llamastack}",
"table_name": "${env.POSTGRES_TABLE_NAME:" + table_name + "}",
"host": "${env.POSTGRES_HOST:=localhost}",
"port": "${env.POSTGRES_PORT:=5432}",
"db": "${env.POSTGRES_DB:=llamastack}",
"user": "${env.POSTGRES_USER:=llamastack}",
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
"table_name": "${env.POSTGRES_TABLE_NAME:=" + table_name + "}",
}
@classmethod
@ -131,12 +131,12 @@ class MongoDBKVStoreConfig(CommonConfig):
return {
"type": "mongodb",
"namespace": None,
"host": "${env.MONGODB_HOST:localhost}",
"port": "${env.MONGODB_PORT:5432}",
"host": "${env.MONGODB_HOST:=localhost}",
"port": "${env.MONGODB_PORT:=5432}",
"db": "${env.MONGODB_DB}",
"user": "${env.MONGODB_USER}",
"password": "${env.MONGODB_PASSWORD}",
"collection_name": "${env.MONGODB_COLLECTION_NAME:" + collection_name + "}",
"collection_name": "${env.MONGODB_COLLECTION_NAME:=" + collection_name + "}",
}

View file

@ -12,8 +12,7 @@ import uuid
from abc import ABC, abstractmethod
from typing import Any
from llama_stack.apis.files import Files
from llama_stack.apis.files.files import OpenAIFileObject
from llama_stack.apis.files import Files, OpenAIFileObject
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,

View file

@ -7,6 +7,7 @@ import base64
import io
import logging
import re
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
@ -23,12 +24,13 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
log = logging.getLogger(__name__)
@ -148,6 +150,7 @@ async def content_from_doc(doc: RAGDocument) -> str:
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
) -> list[Chunk]:
default_tokenizer = "DEFAULT_TIKTOKEN_TOKENIZER"
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
try:
@ -161,16 +164,32 @@ def make_overlapped_chunks(
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
chunk_id = generate_chunk_id(chunk, text)
chunk_metadata = metadata.copy()
chunk_metadata["chunk_id"] = chunk_id
chunk_metadata["document_id"] = document_id
chunk_metadata["token_count"] = len(toks)
chunk_metadata["metadata_token_count"] = len(metadata_tokens)
backend_chunk_metadata = ChunkMetadata(
chunk_id=chunk_id,
document_id=document_id,
source=metadata.get("source", None),
created_timestamp=metadata.get("created_timestamp", int(time.time())),
updated_timestamp=int(time.time()),
chunk_window=f"{i}-{i + len(toks)}",
chunk_tokenizer=default_tokenizer,
chunk_embedding_model=None, # This will be set in `VectorDBWithIndex.insert_chunks`
content_token_count=len(toks),
metadata_token_count=len(metadata_tokens),
)
# chunk is a string
chunks.append(
Chunk(
content=chunk,
metadata=chunk_metadata,
chunk_metadata=backend_chunk_metadata,
)
)
@ -237,6 +256,9 @@ class VectorDBWithIndex:
for i, c in enumerate(chunks):
if c.embedding is None:
chunks_to_embed.append(c)
if c.chunk_metadata:
c.chunk_metadata.chunk_embedding_model = self.vector_db.embedding_model
c.chunk_metadata.chunk_embedding_dimension = self.vector_db.embedding_dimension
else:
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
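
The chunk windows produced by make_overlapped_chunks advance by window_len - overlap_len tokens, so consecutive chunks share overlap_len tokens. A small illustration of that arithmetic (the helper name is illustrative):

def window_ranges(num_tokens: int, window_len: int, overlap_len: int) -> list[tuple[int, int]]:
    """Start/end token indices for each overlapped chunk window."""
    step = window_len - overlap_len
    return [(start, min(start + window_len, num_tokens)) for start in range(0, num_tokens, step)]

# window_ranges(10, 4, 2) -> [(0, 4), (2, 6), (4, 8), (6, 10), (8, 10)]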

View file

@ -50,7 +50,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
return cls(
type="sqlite",
db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
db_path="${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
)
@property
@ -78,11 +78,11 @@ class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
def sample_run_config(cls, **kwargs):
return cls(
type="postgres",
host="${env.POSTGRES_HOST:localhost}",
port="${env.POSTGRES_PORT:5432}",
db="${env.POSTGRES_DB:llamastack}",
user="${env.POSTGRES_USER:llamastack}",
password="${env.POSTGRES_PASSWORD:llamastack}",
host="${env.POSTGRES_HOST:=localhost}",
port="${env.POSTGRES_PORT:=5432}",
db="${env.POSTGRES_DB:=llamastack}",
user="${env.POSTGRES_USER:=llamastack}",
password="${env.POSTGRES_PASSWORD:=llamastack}",
)

Some files were not shown because too many files have changed in this diff.