Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-07 19:12:09 +00:00)

commit d7dbc8cf64: Merge branch 'pr1573' into api_2
21 changed files with 673 additions and 232 deletions
.github/workflows/unit-tests.yml (vendored, 2 changes)

@@ -1,6 +1,8 @@
 name: Unit Tests
 
 on:
+  push:
+    branches: [ main ]
   pull_request:
     branches: [ main ]
   workflow_dispatch:

README.md (1 change)

[Badge images were stripped in this capture; only link targets survive, and the content of the one added badge line was lost.]

@@ -4,6 +4,7 @@
 [](https://pypi.org/project/llama-stack/)
 [](https://github.com/meta-llama/llama-stack/blob/main/LICENSE)
 [](https://discord.gg/llama-stack)
+
 
 [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)

docs/_static/llama-stack-spec.html (vendored, 198 changes)

@@ -4570,7 +4570,7 @@
         "metrics": {
           "type": "array",
           "items": {
-            "$ref": "#/components/schemas/MetricEvent"
+            "$ref": "#/components/schemas/MetricInResponse"
          }
        },
        "completion_message": {
@@ -4592,46 +4592,9 @@
       "title": "ChatCompletionResponse",
       "description": "Response from a chat completion request."
     },
-    "MetricEvent": {
+    "MetricInResponse": {
       "type": "object",
       "properties": {
-        "trace_id": {
-          "type": "string"
-        },
-        "span_id": {
-          "type": "string"
-        },
-        "timestamp": {
-          "type": "string",
-          "format": "date-time"
-        },
-        "attributes": {
-          "type": "object",
-          "additionalProperties": {
-            "oneOf": [
-              {
-                "type": "string"
-              },
-              {
-                "type": "integer"
-              },
-              {
-                "type": "number"
-              },
-              {
-                "type": "boolean"
-              },
-              {
-                "type": "null"
-              }
-            ]
-          }
-        },
-        "type": {
-          "type": "string",
-          "const": "metric",
-          "default": "metric"
-        },
         "metric": {
           "type": "string"
         },
@@ -4651,15 +4614,10 @@
       },
       "additionalProperties": false,
       "required": [
-        "trace_id",
-        "span_id",
-        "timestamp",
-        "type",
         "metric",
-        "value",
-        "unit"
+        "value"
       ],
-      "title": "MetricEvent"
+      "title": "MetricInResponse"
     },
     "TokenLogProbs": {
       "type": "object",
@@ -4736,6 +4694,12 @@
     "CompletionResponse": {
       "type": "object",
       "properties": {
+        "metrics": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/MetricInResponse"
+          }
+        },
        "content": {
          "type": "string",
          "description": "The generated completion text"
@@ -4945,7 +4909,7 @@
        "metrics": {
          "type": "array",
          "items": {
-           "$ref": "#/components/schemas/MetricEvent"
+           "$ref": "#/components/schemas/MetricInResponse"
          }
        },
        "event": {
@@ -5103,6 +5067,12 @@
     "CompletionResponseStreamChunk": {
       "type": "object",
       "properties": {
+        "metrics": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/MetricInResponse"
+          }
+        },
        "delta": {
          "type": "string",
          "description": "New content generated since last chunk. This can be one or more tokens."
@@ -7192,15 +7162,16 @@
          "const": "dataset",
          "default": "dataset"
        },
-       "schema": {
+       "purpose": {
          "type": "string",
          "enum": [
-           "messages"
+           "post-training/messages",
+           "eval/question-answer"
          ],
-         "title": "Schema",
-         "description": "Schema of the dataset. Each type has a different column format."
+         "title": "DatasetPurpose",
+         "description": "Purpose of the dataset. Each type has a different column format."
        },
-       "data_source": {
+       "source": {
          "$ref": "#/components/schemas/DataSource"
        },
        "metadata": {
@@ -7235,8 +7206,8 @@
        "provider_resource_id",
        "provider_id",
        "type",
-       "schema",
-       "data_source",
+       "purpose",
+       "source",
        "metadata"
      ],
      "title": "Dataset"
@@ -7249,8 +7220,9 @@
          "const": "huggingface",
          "default": "huggingface"
        },
-       "dataset_path": {
-         "type": "string"
+       "path": {
+         "type": "string",
+         "description": "The path to the dataset in Huggingface. E.g. - \"llamastack/simpleqa\""
        },
        "params": {
          "type": "object",
@@ -7275,16 +7247,18 @@
              "type": "object"
            }
          ]
-        }
+        },
+        "description": "The parameters for the dataset."
        }
      },
      "additionalProperties": false,
      "required": [
        "type",
-       "dataset_path",
+       "path",
        "params"
      ],
-     "title": "HuggingfaceDataSource"
+     "title": "HuggingfaceDataSource",
+     "description": "A dataset stored in Huggingface."
    },
    "RowsDataSource": {
      "type": "object",
@@ -7320,7 +7294,8 @@
          }
        ]
-      }
+      },
+      "description": "The dataset is stored in rows. E.g. - [ {\"messages\": [{\"role\": \"user\", \"content\": \"Hello, world!\"}, {\"role\": \"assistant\", \"content\": \"Hello, world!\"}]} ]"
      }
    },
    "additionalProperties": false,
@@ -7328,7 +7303,8 @@
      "type",
      "rows"
    ],
-   "title": "RowsDataSource"
+   "title": "RowsDataSource",
+   "description": "A dataset stored in rows."
   },
   "URIDataSource": {
     "type": "object",
@@ -7339,7 +7315,8 @@
       "default": "uri"
     },
     "uri": {
-      "type": "string"
+      "type": "string",
+      "description": "The dataset can be obtained from a URI. E.g. - \"https://mywebsite.com/mydata.jsonl\" - \"lsfs://mydata.jsonl\" - \"data:csv;base64,{base64_content}\""
     }
   },
   "additionalProperties": false,
@@ -7347,7 +7324,8 @@
     "type",
     "uri"
   ],
-  "title": "URIDataSource"
+  "title": "URIDataSource",
+  "description": "A dataset that can be obtained from a URI."
 },
 "Model": {
   "type": "object",
@@ -8634,6 +8612,75 @@
      ],
      "title": "LogSeverity"
    },
+   "MetricEvent": {
+     "type": "object",
+     "properties": {
+       "trace_id": {
+         "type": "string"
+       },
+       "span_id": {
+         "type": "string"
+       },
+       "timestamp": {
+         "type": "string",
+         "format": "date-time"
+       },
+       "attributes": {
+         "type": "object",
+         "additionalProperties": {
+           "oneOf": [
+             {
+               "type": "string"
+             },
+             {
+               "type": "integer"
+             },
+             {
+               "type": "number"
+             },
+             {
+               "type": "boolean"
+             },
+             {
+               "type": "null"
+             }
+           ]
+         }
+       },
+       "type": {
+         "type": "string",
+         "const": "metric",
+         "default": "metric"
+       },
+       "metric": {
+         "type": "string"
+       },
+       "value": {
+         "oneOf": [
+           {
+             "type": "integer"
+           },
+           {
+             "type": "number"
+           }
+         ]
+       },
+       "unit": {
+         "type": "string"
+       }
+     },
+     "additionalProperties": false,
+     "required": [
+       "trace_id",
+       "span_id",
+       "timestamp",
+       "type",
+       "metric",
+       "value",
+       "unit"
+     ],
+     "title": "MetricEvent"
+   },
    "SpanEndPayload": {
      "type": "object",
      "properties": {
@@ -9510,14 +9557,15 @@
    "RegisterDatasetRequest": {
      "type": "object",
      "properties": {
-       "schema": {
+       "purpose": {
          "type": "string",
          "enum": [
-           "messages"
+           "post-training/messages",
+           "eval/question-answer"
          ],
-         "description": "The schema format of the dataset. One of - messages: The dataset contains a messages column with list of messages for post-training."
+         "description": "The purpose of the dataset. One of - \"post-training/messages\": The dataset contains a messages column with list of messages for post-training. - \"eval/question-answer\": The dataset contains a question and answer column."
        },
-       "data_source": {
+       "source": {
          "$ref": "#/components/schemas/DataSource",
          "description": "The data source of the dataset. Examples: - { \"type\": \"uri\", \"uri\": \"https://mywebsite.com/mydata.jsonl\" } - { \"type\": \"uri\", \"uri\": \"lsfs://mydata.jsonl\" } - { \"type\": \"huggingface\", \"dataset_path\": \"tatsu-lab/alpaca\", \"params\": { \"split\": \"train\" } } - { \"type\": \"rows\", \"rows\": [ { \"messages\": [ {\"role\": \"user\", \"content\": \"Hello, world!\"}, {\"role\": \"assistant\", \"content\": \"Hello, world!\"}, ] } ] }"
        },
@@ -9554,8 +9602,8 @@
      },
      "additionalProperties": false,
      "required": [
-       "schema",
-       "data_source"
+       "purpose",
+       "source"
      ],
      "title": "RegisterDatasetRequest"
    },
@@ -9769,21 +9817,11 @@
    "type": "object",
    "properties": {
      "tool_responses": {
-       "oneOf": [
-         {
-           "type": "array",
-           "items": {
-             "$ref": "#/components/schemas/ToolResponse"
-           }
-         },
-         {
-           "type": "array",
-           "items": {
-             "$ref": "#/components/schemas/ToolResponseMessage"
-           }
-         }
-       ],
-       "description": "The tool call responses to resume the turn with. NOTE: ToolResponseMessage will be deprecated. Use ToolResponse."
+       "type": "array",
+       "items": {
+         "$ref": "#/components/schemas/ToolResponse"
+       },
+       "description": "The tool call responses to resume the turn with."
      },
      "stream": {
        "type": "boolean",

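Net effect of the MetricInResponse hunks above: chat and completion responses now embed a slimmed-down metrics list, while the full MetricEvent (with trace/span identifiers and timestamp) moves to the telemetry section of the spec. A minimal sketch of a response body under the new schema; the message content and token counts are illustrative, not taken from the diff:

# Illustrative response body under the new spec; only "metric" and "value"
# are required on each metrics entry, "unit" is optional.
chat_completion_response = {
    "completion_message": {
        "role": "assistant",
        "content": "Hello!",
        "stop_reason": "end_of_turn",
    },
    "metrics": [
        {"metric": "prompt_tokens", "value": 12, "unit": "tokens"},
        {"metric": "completion_tokens", "value": 3, "unit": "tokens"},
        {"metric": "total_tokens", "value": 15, "unit": "tokens"},
    ],
}
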
docs/_static/llama-stack-spec.yaml (vendored, 140 changes)

@@ -3115,7 +3115,7 @@ components:
       metrics:
         type: array
         items:
-          $ref: '#/components/schemas/MetricEvent'
+          $ref: '#/components/schemas/MetricInResponse'
       completion_message:
         $ref: '#/components/schemas/CompletionMessage'
         description: The complete response message
@@ -3130,29 +3130,9 @@ components:
       - completion_message
       title: ChatCompletionResponse
       description: Response from a chat completion request.
-    MetricEvent:
+    MetricInResponse:
       type: object
       properties:
-        trace_id:
-          type: string
-        span_id:
-          type: string
-        timestamp:
-          type: string
-          format: date-time
-        attributes:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: string
-              - type: integer
-              - type: number
-              - type: boolean
-              - type: 'null'
-        type:
-          type: string
-          const: metric
-          default: metric
         metric:
           type: string
         value:
@@ -3163,14 +3143,9 @@ components:
           type: string
       additionalProperties: false
       required:
-        - trace_id
-        - span_id
-        - timestamp
-        - type
         - metric
         - value
-        - unit
-      title: MetricEvent
+      title: MetricInResponse
     TokenLogProbs:
       type: object
       properties:
@@ -3227,6 +3202,10 @@ components:
     CompletionResponse:
       type: object
       properties:
+        metrics:
+          type: array
+          items:
+            $ref: '#/components/schemas/MetricInResponse'
         content:
           type: string
           description: The generated completion text
@@ -3426,7 +3405,7 @@ components:
       metrics:
         type: array
         items:
-          $ref: '#/components/schemas/MetricEvent'
+          $ref: '#/components/schemas/MetricInResponse'
       event:
         $ref: '#/components/schemas/ChatCompletionResponseEvent'
         description: The event containing the new content
@@ -3545,6 +3524,10 @@ components:
     CompletionResponseStreamChunk:
       type: object
       properties:
+        metrics:
+          type: array
+          items:
+            $ref: '#/components/schemas/MetricInResponse'
         delta:
           type: string
           description: >-
@@ -5008,14 +4991,15 @@ components:
           type: string
           const: dataset
           default: dataset
-        schema:
+        purpose:
           type: string
           enum:
-            - messages
-          title: Schema
+            - post-training/messages
+            - eval/question-answer
+          title: DatasetPurpose
           description: >-
-            Schema of the dataset. Each type has a different column format.
+            Purpose of the dataset. Each type has a different column format.
-        data_source:
+        source:
           $ref: '#/components/schemas/DataSource'
         metadata:
           type: object
@@ -5033,8 +5017,8 @@ components:
       - provider_resource_id
       - provider_id
       - type
-      - schema
-      - data_source
+      - purpose
+      - source
       - metadata
       title: Dataset
     HuggingfaceDataSource:
@@ -5044,8 +5028,10 @@ components:
           type: string
           const: huggingface
           default: huggingface
-        dataset_path:
+        path:
           type: string
+          description: >-
+            The path to the dataset in Huggingface. E.g. - "llamastack/simpleqa"
         params:
           type: object
           additionalProperties:
@@ -5056,12 +5042,14 @@ components:
             - type: string
             - type: array
             - type: object
+          description: The parameters for the dataset.
       additionalProperties: false
       required:
         - type
-        - dataset_path
+        - path
         - params
       title: HuggingfaceDataSource
+      description: A dataset stored in Huggingface.
     RowsDataSource:
       type: object
       properties:
@@ -5081,11 +5069,16 @@ components:
             - type: string
             - type: array
            - type: object
+          description: >-
+            The dataset is stored in rows. E.g. - [ {"messages": [{"role": "user",
+            "content": "Hello, world!"}, {"role": "assistant", "content": "Hello,
+            world!"}]} ]
       additionalProperties: false
       required:
         - type
         - rows
       title: RowsDataSource
+      description: A dataset stored in rows.
     URIDataSource:
       type: object
       properties:
@@ -5095,11 +5088,16 @@ components:
           default: uri
         uri:
           type: string
+          description: >-
+            The dataset can be obtained from a URI. E.g. - "https://mywebsite.com/mydata.jsonl"
+            - "lsfs://mydata.jsonl" - "data:csv;base64,{base64_content}"
       additionalProperties: false
       required:
         - type
         - uri
       title: URIDataSource
+      description: >-
+        A dataset that can be obtained from a URI.
     Model:
       type: object
       properties:
@@ -5920,6 +5918,47 @@ components:
       - error
       - critical
       title: LogSeverity
+    MetricEvent:
+      type: object
+      properties:
+        trace_id:
+          type: string
+        span_id:
+          type: string
+        timestamp:
+          type: string
+          format: date-time
+        attributes:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: string
+              - type: integer
+              - type: number
+              - type: boolean
+              - type: 'null'
+        type:
+          type: string
+          const: metric
+          default: metric
+        metric:
+          type: string
+        value:
+          oneOf:
+            - type: integer
+            - type: number
+        unit:
+          type: string
+      additionalProperties: false
+      required:
+        - trace_id
+        - span_id
+        - timestamp
+        - type
+        - metric
+        - value
+        - unit
+      title: MetricEvent
     SpanEndPayload:
       type: object
       properties:
@@ -6483,14 +6522,16 @@ components:
     RegisterDatasetRequest:
      type: object
      properties:
-        schema:
+        purpose:
          type: string
          enum:
-            - messages
+            - post-training/messages
+            - eval/question-answer
          description: >-
-            The schema format of the dataset. One of - messages: The dataset contains
-            a messages column with list of messages for post-training.
-        data_source:
+            The purpose of the dataset. One of - "post-training/messages": The dataset
+            contains a messages column with list of messages for post-training. -
+            "eval/question-answer": The dataset contains a question and answer column.
+        source:
          $ref: '#/components/schemas/DataSource'
          description: >-
            The data source of the dataset. Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl"
@@ -6517,8 +6558,8 @@ components:
        The ID of the dataset. If not provided, a random ID will be generated.
      additionalProperties: false
      required:
-        - schema
-        - data_source
+        - purpose
+        - source
      title: RegisterDatasetRequest
    RegisterModelRequest:
      type: object
@@ -6643,16 +6684,11 @@ components:
      type: object
      properties:
        tool_responses:
-          oneOf:
-            - type: array
-              items:
-                $ref: '#/components/schemas/ToolResponse'
-            - type: array
-              items:
-                $ref: '#/components/schemas/ToolResponseMessage'
+          type: array
+          items:
+            $ref: '#/components/schemas/ToolResponse'
          description: >-
-            The tool call responses to resume the turn with. NOTE: ToolResponseMessage
-            will be deprecated. Use ToolResponse.
+            The tool call responses to resume the turn with.
        stream:
          type: boolean
          description: Whether to stream the response.

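The YAML hunks mirror the HTML spec: `schema` becomes `purpose`, `data_source` becomes `source`, and HuggingfaceDataSource's `dataset_path` becomes `path`. A hedged sketch of a request body for POST /datasets under the renamed fields; the URI comes from the spec's own description text, and the metadata is illustrative:

# Sketch of a RegisterDatasetRequest payload after the rename.
register_dataset_request = {
    "purpose": "eval/question-answer",
    "source": {
        "type": "uri",
        "uri": "https://mywebsite.com/mydata.jsonl",
    },
    "metadata": {"owner": "docs-example"},  # optional, illustrative
}
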
@@ -370,7 +370,7 @@ class AgentTurnResumeRequest(BaseModel):
     agent_id: str
     session_id: str
     turn_id: str
-    tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]]
+    tool_responses: List[ToolResponse]
     stream: Optional[bool] = False
 
 
@@ -449,7 +449,7 @@ class Agents(Protocol):
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
+        tool_responses: List[ToolResponse],
         stream: Optional[bool] = False,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
         """Resume an agent turn with executed tool call responses.
@@ -460,7 +460,6 @@ class Agents(Protocol):
         :param session_id: The ID of the session to resume.
         :param turn_id: The ID of the turn to resume.
         :param tool_responses: The tool call responses to resume the turn with.
-            NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
         :param stream: Whether to stream the response.
         :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
         """

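With the Union removed, resume_agent_turn accepts only List[ToolResponse]. A minimal caller sketch against the updated protocol; the IDs and tool output are hypothetical, and the ToolResponse import path and field names are assumptions, not confirmed by this diff:

from llama_stack.apis.inference import ToolResponse  # import path assumed

# Hypothetical IDs; the keyword arguments follow the updated signature above.
turn = await agents.resume_agent_turn(
    agent_id="agent-123",
    session_id="session-456",
    turn_id="turn-789",
    tool_responses=[
        ToolResponse(call_id="call-1", tool_name="web_search", content="top results ..."),
    ],
    stream=False,
)
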
@@ -13,10 +13,10 @@ from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
 
-class Schema(Enum):
+class DatasetPurpose(Enum):
     """
-    Schema of the dataset. Each type has a different column format.
+    Purpose of the dataset. Each type has a different column format.
-    :cvar messages: The dataset contains messages used for post-training. Examples:
+    :cvar post-training/messages: The dataset contains messages used for post-training. Examples:
         {
             "messages": [
                 {"role": "user", "content": "Hello, world!"},
@@ -25,11 +25,19 @@ class Schema(Enum):
         }
     """
 
-    messages = "messages"
+    post_training_messages = "post-training/messages"
+    eval_question_answer = "eval/question-answer"
 
     # TODO: add more schemas here
 
 
 class DatasetType(Enum):
+    """
+    Type of the dataset source.
+    :cvar huggingface: The dataset is stored in Huggingface.
+    :cvar uri: The dataset can be obtained from a URI.
+    :cvar rows: The dataset is stored in rows.
+    """
     huggingface = "huggingface"
     uri = "uri"
     rows = "rows"
@@ -37,19 +45,36 @@ class DatasetType(Enum):
 
 @json_schema_type
 class URIDataSource(BaseModel):
+    """A dataset that can be obtained from a URI.
+    :param uri: The dataset can be obtained from a URI. E.g.
+    - "https://mywebsite.com/mydata.jsonl"
+    - "lsfs://mydata.jsonl"
+    - "data:csv;base64,{base64_content}"
+    """
     type: Literal["uri"] = "uri"
     uri: str
 
 
 @json_schema_type
 class HuggingfaceDataSource(BaseModel):
+    """A dataset stored in Huggingface.
+    :param path: The path to the dataset in Huggingface. E.g.
+    - "llamastack/simpleqa"
+    :param params: The parameters for the dataset.
+    """
     type: Literal["huggingface"] = "huggingface"
-    dataset_path: str
+    path: str
     params: Dict[str, Any]
 
 
 @json_schema_type
 class RowsDataSource(BaseModel):
+    """A dataset stored in rows.
+    :param rows: The dataset is stored in rows. E.g.
+    - [
+        {"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
+    ]
+    """
     type: Literal["rows"] = "rows"
     rows: List[Dict[str, Any]]
 
@@ -64,8 +89,11 @@ DataSource = register_schema(
 
 
 class CommonDatasetFields(BaseModel):
-    schema: Schema
-    data_source: DataSource
+    """
+    Common fields for a dataset.
+    """
+    purpose: DatasetPurpose
+    source: DataSource
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this dataset",
@@ -99,17 +127,18 @@ class Datasets(Protocol):
     @webmethod(route="/datasets", method="POST")
     async def register_dataset(
         self,
-        schema: Schema,
-        data_source: DataSource,
+        purpose: DatasetPurpose,
+        source: DataSource,
         metadata: Optional[Dict[str, Any]] = None,
         dataset_id: Optional[str] = None,
     ) -> Dataset:
         """
         Register a new dataset.
 
-        :param schema: The schema format of the dataset. One of
-            - messages: The dataset contains a messages column with list of messages for post-training.
-        :param data_source: The data source of the dataset. Examples:
+        :param purpose: The purpose of the dataset. One of
+            - "post-training/messages": The dataset contains a messages column with list of messages for post-training.
+            - "eval/question-answer": The dataset contains a question and answer column.
+        :param source: The data source of the dataset. Examples:
            - {
                "type": "uri",
                "uri": "https://mywebsite.com/mydata.jsonl"

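Under the renamed API, registering a dataset looks like the sketch below; it assumes a resolved Datasets implementation in scope (here `datasets`) and uses one of the two purposes defined by the new enum:

from llama_stack.apis.datasets import DatasetPurpose, URIDataSource

# Sketch: register an eval dataset from a URI; dataset_id stays auto-generated.
dataset = await datasets.register_dataset(
    purpose=DatasetPurpose.eval_question_answer,
    source=URIDataSource(uri="https://mywebsite.com/mydata.jsonl"),
    metadata={"owner": "docs-example"},  # illustrative
)
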
@@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):
 
 
 @json_schema_type
-class CompletionResponse(BaseModel):
+class CompletionResponse(MetricResponseMixin):
     """Response from a completion request.
 
     :param content: The generated completion text
@@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):
 
 
 @json_schema_type
-class CompletionResponseStreamChunk(BaseModel):
+class CompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed completion response.
 
     :param delta: New content generated since last chunk. This can be one or more tokens.
@@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
+class ChatCompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed chat completion response.
 
     :param event: The event containing the new content
@@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
 
 
 @json_schema_type
-class ChatCompletionResponse(MetricResponseMixin, BaseModel):
+class ChatCompletionResponse(MetricResponseMixin):
     """Response from a chat completion request.
 
     :param completion_message: The complete response message

@@ -96,6 +96,13 @@ class MetricEvent(EventCommon):
     unit: str
 
 
+@json_schema_type
+class MetricInResponse(BaseModel):
+    metric: str
+    value: Union[int, float]
+    unit: Optional[str] = None
+
+
 # This is a short term solution to allow inference API to return metrics
 # The ideal way to do this is to have a way for all response types to include metrics
 # and all metric events logged to the telemetry API to be inlcuded with the response
@@ -117,7 +124,7 @@ class MetricEvent(EventCommon):
 
 
 class MetricResponseMixin(BaseModel):
-    metrics: Optional[List[MetricEvent]] = None
+    metrics: Optional[List[MetricInResponse]] = None
 
 
 @json_schema_type

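Because MetricResponseMixin now carries List[MetricInResponse], callers can read token usage straight off any inference response without touching telemetry types. A small sketch, assuming `response` is a ChatCompletionResponse (or any class mixing in MetricResponseMixin):

# Sketch: fold the per-response metrics into a name -> value dict.
usage = {m.metric: m.value for m in (response.metrics or [])}
print(usage.get("prompt_tokens"), usage.get("completion_tokens"), usage.get("total_tokens"))
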
@@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
 from llama_stack.distribution.request_headers import (
-    preserve_headers_context_async_generator,
+    PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import ProviderRegistry
@@ -44,8 +44,10 @@ from llama_stack.distribution.stack import (
     redact_sensitive_fields,
     replace_env_vars,
 )
+from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.distribution.utils.exec import in_notebook
 from llama_stack.providers.utils.telemetry.tracing import (
+    CURRENT_TRACE_CONTEXT,
     end_trace,
     setup_logger,
     start_trace,
@@ -384,8 +386,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         finally:
             await end_trace()
 
-        wrapped_gen = preserve_headers_context_async_generator(gen())
+        # Wrap the generator to preserve context across iterations
+        wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=wrapped_gen,

@@ -7,14 +7,14 @@
 import contextvars
 import json
 import logging
-from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
+from typing import Any, ContextManager, Dict, Optional
 
 from .utils.dynamic import instantiate_class_type
 
 log = logging.getLogger(__name__)
 
 # Context variable for request provider data
-_provider_data_var = contextvars.ContextVar("provider_data", default=None)
+PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
 
 
 class RequestProviderDataContext(ContextManager):
@@ -26,40 +26,13 @@ class RequestProviderDataContext(ContextManager):
 
     def __enter__(self):
         # Save the current value and set the new one
-        self.token = _provider_data_var.set(self.provider_data)
+        self.token = PROVIDER_DATA_VAR.set(self.provider_data)
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         # Restore the previous value
         if self.token is not None:
-            _provider_data_var.reset(self.token)
+            PROVIDER_DATA_VAR.reset(self.token)
 
 
-T = TypeVar("T")
-
-
-def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
-    """
-    Wraps an async generator to preserve request headers context variables across iterations.
-
-    This ensures that context variables set during generator creation are
-    available during each iteration of the generator, even if the original
-    context manager has exited.
-    """
-    # Capture the current context value right now
-    context_value = _provider_data_var.get()
-
-    async def wrapper():
-        while True:
-            # Set context before each anext() call
-            _ = _provider_data_var.set(context_value)
-            try:
-                item = await gen.__anext__()
-                yield item
-            except StopAsyncIteration:
-                break
-
-    return wrapper()
-
-
 class NeedsRequestProviderData:
@@ -72,7 +45,7 @@ class NeedsRequestProviderData:
         if not validator_class:
             raise ValueError(f"Provider {provider_type} does not have a validator")
 
-        val = _provider_data_var.get()
+        val = PROVIDER_DATA_VAR.get()
         if not val:
             return None

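The module-level ContextVar is now public (PROVIDER_DATA_VAR), so other modules, such as the preserve_contexts_async_generator call sites above, can snapshot and restore it. A usage sketch; the header name and payload are illustrative assumptions, not taken from this diff:

from llama_stack.distribution.request_headers import (
    PROVIDER_DATA_VAR,
    request_provider_data_context,
)

# Inside the context manager the var holds the parsed provider data;
# __exit__ restores the previous value via the saved token.
with request_provider_data_context({"X-LlamaStack-Provider-Data": '{"api_key": "..."}'}):
    print(PROVIDER_DATA_VAR.get())
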
@@ -165,7 +165,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
                 module="llama_stack.distribution.routers",
                 routing_table_api=info.routing_table_api,
                 api_dependencies=[info.routing_table_api],
-                deps__=[info.routing_table_api.value],
+                # Add telemetry as an optional dependency to all auto-routed providers
+                optional_api_dependencies=[Api.telemetry],
+                deps__=([info.routing_table_api.value, Api.telemetry.value]),
             ),
         )
     }

@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl
 
 
-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
         "eval": EvalRouter,
         "tool_runtime": ToolRuntimeRouter,
     }
+    api_to_deps = {
+        "inference": {"telemetry": Api.telemetry},
+    }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")
 
-    impl = api_to_routers[api.value](routing_table)
+    api_to_dep_impl = {}
+    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
+        if dep_api in deps:
+            api_to_dep_impl[dep_name] = deps[dep_api]
+
+    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl

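For the inference API this wiring is equivalent to passing the resolved telemetry implementation through as a keyword argument. A sketch, assuming `deps` maps Api members to resolved implementations as the resolver provides:

# Equivalent of the new wiring for api == Api.inference (sketch):
api_to_dep_impl = {}
if Api.telemetry in deps:
    api_to_dep_impl["telemetry"] = deps[Api.telemetry]

impl = InferenceRouter(routing_table, **api_to_dep_impl)
await impl.initialize()
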
@@ -4,7 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict, List, Optional
+import time
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from llama_stack.apis.common.content_types import (
     URL,
@@ -20,6 +21,10 @@ from llama_stack.apis.eval import (
     JobStatus,
 )
 from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
@@ -27,13 +32,14 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     SamplingParams,
+    StopReason,
     TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import ModelType
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
@@ -42,6 +48,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -52,7 +59,10 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
+from llama_stack.models.llama.llama3.chat_format import ChatFormat
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
 logger = get_logger(name=__name__, category="core")
 
@@ -119,9 +129,14 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
+        telemetry: Optional[Telemetry] = None,
     ) -> None:
         logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
+        self.telemetry = telemetry
+        if self.telemetry:
+            self.tokenizer = Tokenizer.get_instance()
+            self.formatter = ChatFormat(self.tokenizer)
 
     async def initialize(self) -> None:
         logger.debug("InferenceRouter.initialize")
@@ -144,6 +159,71 @@ class InferenceRouter(Inference):
         )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
+    def _construct_metrics(
+        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
+    ) -> List[MetricEvent]:
+        """Constructs a list of MetricEvent objects containing token usage metrics.
+
+        Args:
+            prompt_tokens: Number of tokens in the prompt
+            completion_tokens: Number of tokens in the completion
+            total_tokens: Total number of tokens used
+            model: Model object containing model_id and provider_id
+
+        Returns:
+            List of MetricEvent objects with token usage metrics
+        """
+        span = get_current_span()
+        if span is None:
+            logger.warning("No span found for token usage metrics")
+            return []
+        metrics = [
+            ("prompt_tokens", prompt_tokens),
+            ("completion_tokens", completion_tokens),
+            ("total_tokens", total_tokens),
+        ]
+        metric_events = []
+        for metric_name, value in metrics:
+            metric_events.append(
+                MetricEvent(
+                    trace_id=span.trace_id,
+                    span_id=span.span_id,
+                    metric=metric_name,
+                    value=value,
+                    timestamp=time.time(),
+                    unit="tokens",
+                    attributes={
+                        "model_id": model.model_id,
+                        "provider_id": model.provider_id,
+                    },
+                )
+            )
+        return metric_events
+
+    async def _compute_and_log_token_usage(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        total_tokens: int,
+        model: Model,
+    ) -> List[MetricInResponse]:
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        if self.telemetry:
+            for metric in metrics:
+                await self.telemetry.log_event(metric)
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+
+    async def _count_tokens(
+        self,
+        messages: List[Message] | InterleavedContent,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+    ) -> Optional[int]:
+        if isinstance(messages, list):
+            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        else:
+            encoded = self.formatter.encode_content(messages)
+        return len(encoded.tokens) if encoded and encoded.tokens else 0
+
     async def chat_completion(
         self,
         model_id: str,
@@ -156,7 +236,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -206,10 +286,47 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+
         if stream:
-            return (chunk async for chunk in await provider.chat_completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.chat_completion(**params):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                        if chunk.event.delta.type == "text":
+                            completion_text += chunk.event.delta.text
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                        completion_tokens = await self._count_tokens(
+                            [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
+                            tool_config.tool_prompt_format,
+                        )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        metrics = await self._compute_and_log_token_usage(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.chat_completion(**params)
+            response = await provider.chat_completion(**params)
+            completion_tokens = await self._count_tokens(
+                [response.completion_message],
+                tool_config.tool_prompt_format,
+            )
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            metrics = await self._compute_and_log_token_usage(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+            )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
+            return response
 
     async def completion(
         self,
@@ -239,10 +356,41 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
+
+        prompt_tokens = await self._count_tokens(content)
+
         if stream:
-            return (chunk async for chunk in await provider.completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.completion(**params):
+                    if hasattr(chunk, "delta"):
+                        completion_text += chunk.delta
+                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                        completion_tokens = await self._count_tokens(completion_text)
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        metrics = await self._compute_and_log_token_usage(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.completion(**params)
+            response = await provider.completion(**params)
+            completion_tokens = await self._count_tokens(response.content)
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            metrics = await self._compute_and_log_token_usage(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+            )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
+            return response
 
     async def embeddings(
         self,

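Both the chat and plain completion paths follow the same accounting: prompt tokens are counted up front, completion tokens once the stream finishes (or from the final message in the non-streaming case), and the total is their sum. A worked sketch of the arithmetic and the shape of the result:

# Numbers are illustrative; they stand in for _count_tokens results.
prompt_tokens = 128       # counted from the request messages/content
completion_tokens = 42    # counted from the accumulated completion text
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)  # 170

# _compute_and_log_token_usage logs three MetricEvents (when telemetry is
# configured) and returns the matching MetricInResponse entries, which get
# attached to response.metrics / chunk.metrics.
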
@@ -28,7 +28,7 @@ from typing_extensions import Annotated
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import (
-    preserve_headers_context_async_generator,
+    PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import InvalidProviderError
@@ -38,6 +38,7 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
     validate_env_pair,
 )
+from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
@@ -45,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
     TelemetryAdapter,
 )
 from llama_stack.providers.utils.telemetry.tracing import (
+    CURRENT_TRACE_CONTEXT,
     end_trace,
     setup_logger,
     start_trace,
@@ -182,7 +184,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):

     try:
         if is_streaming:
-            gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
+            gen = preserve_contexts_async_generator(
+                sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
+            )
             return StreamingResponse(gen, media_type="text/event-stream")
         else:
             value = func(**kwargs)
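
The import swap above matters because the streaming path iterates the SSE generator from a different thread and event loop than the handler that set the request's context variables, and a fresh thread starts with an empty contextvars context. A minimal repro of the loss this causes, using illustrative names rather than llama-stack APIs:

import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar

REQUEST_ID: ContextVar[str] = ContextVar("request_id", default="<unset>")


async def stream():
    # The generator body reads the ContextVar when it is iterated,
    # not when the generator object is created.
    yield REQUEST_ID.get()


def iterate_in_fresh_loop(gen):
    # A new thread starts with an empty contextvars context, so values
    # set on the request thread are invisible here.
    loop = asyncio.new_event_loop()
    try:
        async def consume():
            return [item async for item in gen]
        return loop.run_until_complete(consume())
    finally:
        loop.close()


def main():
    REQUEST_ID.set("req-42")  # set on the "request" thread
    gen = stream()            # created here, iterated elsewhere
    with ThreadPoolExecutor(max_workers=1) as pool:
        print(pool.submit(iterate_in_fresh_loop, gen).result())  # ['<unset>']


main()
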
33
llama_stack/distribution/utils/context.py
Normal file
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from contextvars import ContextVar
+from typing import AsyncGenerator, List, TypeVar
+
+T = TypeVar("T")
+
+
+def preserve_contexts_async_generator(
+    gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
+) -> AsyncGenerator[T, None]:
+    """
+    Wraps an async generator to preserve context variables across iterations.
+    This is needed because we start a new asyncio event loop for each streaming request,
+    and we need to preserve the context across the event loop boundary.
+    """
+
+    async def wrapper() -> AsyncGenerator[T, None]:
+        while True:
+            try:
+                item = await gen.__anext__()
+                context_values = {context_var.name: context_var.get() for context_var in context_vars}
+                yield item
+                for context_var in context_vars:
+                    _ = context_var.set(context_values[context_var.name])
+            except StopAsyncIteration:
+                break
+
+    return wrapper()
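
A minimal usage sketch of the helper above (TENANT is an illustrative variable, not part of llama-stack): the wrapper snapshots the listed ContextVars after each item and re-applies them before fetching the next, so mutations made inside the generator survive even when each iteration runs in a different context copy.

import asyncio
from contextvars import ContextVar

from llama_stack.distribution.utils.context import preserve_contexts_async_generator

TENANT: ContextVar[str] = ContextVar("tenant", default="anonymous")


async def tagged_stream():
    yield TENANT.get()      # sees the value set by the caller
    TENANT.set("tenant-b")  # mutation inside the generator
    yield TENANT.get()      # re-applied by the wrapper on later iterations


async def main():
    TENANT.set("tenant-a")
    wrapped = preserve_contexts_async_generator(tagged_stream(), [TENANT])
    print([item async for item in wrapped])  # ['tenant-a', 'tenant-b']


asyncio.run(main())
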
155
llama_stack/distribution/utils/tests/test_context.py
Normal file
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from contextvars import ContextVar
+
+import pytest
+
+from llama_stack.distribution.utils.context import preserve_contexts_async_generator
+
+
+@pytest.mark.asyncio
+async def test_preserve_contexts_with_exception():
+    # Create context variable
+    context_var = ContextVar("exception_var", default="initial")
+    token = context_var.set("start_value")
+
+    # Create an async generator that raises an exception
+    async def exception_generator():
+        yield context_var.get()
+        context_var.set("modified")
+        raise ValueError("Test exception")
+        yield None  # This will never be reached
+
+    # Wrap the generator
+    wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
+
+    # First iteration should work
+    value = await wrapped_gen.__anext__()
+    assert value == "start_value"
+
+    # Second iteration should raise the exception
+    with pytest.raises(ValueError, match="Test exception"):
+        await wrapped_gen.__anext__()
+
+    # Clean up
+    context_var.reset(token)
+
+
+@pytest.mark.asyncio
+async def test_preserve_contexts_empty_generator():
+    # Create context variable
+    context_var = ContextVar("empty_var", default="initial")
+    token = context_var.set("value")
+
+    # Create an empty async generator
+    async def empty_generator():
+        if False:  # This condition ensures the generator yields nothing
+            yield None
+
+    # Wrap the generator
+    wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
+
+    # The generator should raise StopAsyncIteration immediately
+    with pytest.raises(StopAsyncIteration):
+        await wrapped_gen.__anext__()
+
+    # Context variable should remain unchanged
+    assert context_var.get() == "value"
+
+    # Clean up
+    context_var.reset(token)
+
+
+@pytest.mark.asyncio
+async def test_preserve_contexts_across_event_loops():
+    """
+    Test that context variables are preserved across event loop boundaries with nested generators.
+    This simulates the real-world scenario where:
+    1. A new event loop is created for each streaming request
+    2. The async generator runs inside that loop
+    3. There are multiple levels of nested generators
+    4. Context needs to be preserved across these boundaries
+    """
+    # Create context variables
+    request_id = ContextVar("request_id", default=None)
+    user_id = ContextVar("user_id", default=None)

+    # Set initial values
+
+    # Results container to verify values across thread boundaries
+    results = []
+
+    # Inner-most generator (level 2)
+    async def inner_generator():
+        # Should have the context from the outer scope
+        yield (1, request_id.get(), user_id.get())
+
+        # Modify one context variable
+        user_id.set("user-modified")
+
+        # Should reflect the modification
+        yield (2, request_id.get(), user_id.get())
+
+    # Middle generator (level 1)
+    async def middle_generator():
+        inner_gen = inner_generator()
+
+        # Forward the first yield from inner
+        item = await inner_gen.__anext__()
+        yield item
+
+        # Forward the second yield from inner
+        item = await inner_gen.__anext__()
+        yield item
+
+        request_id.set("req-modified")
+
+        # Add our own yield with both modified variables
+        yield (3, request_id.get(), user_id.get())
+
+    # Function to run in a separate thread with a new event loop
+    def run_in_new_loop():
+        # Create a new event loop for this thread
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        try:
+            # Outer generator (runs in the new loop)
+            async def outer_generator():
+                request_id.set("req-12345")
+                user_id.set("user-6789")
+                # Wrap the middle generator
+                wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
+
+                # Process all items from the middle generator
+                async for item in wrapped_gen:
+                    # Store results for verification
+                    results.append(item)
+
+            # Run the outer generator in the new loop
+            loop.run_until_complete(outer_generator())
+        finally:
+            loop.close()
+
+    # Run the generator chain in a separate thread with a new event loop
+    with ThreadPoolExecutor(max_workers=1) as executor:
+        future = executor.submit(run_in_new_loop)
+        future.result()  # Wait for completion
+
+    # Verify the results
+    assert len(results) == 3
+
+    # First yield should have original values
+    assert results[0] == (1, "req-12345", "user-6789")
+
+    # Second yield should have modified user_id
+    assert results[1] == (2, "req-12345", "user-modified")
+
+    # Third yield should have both modified values
+    assert results[2] == (3, "req-modified", "user-modified")
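
Assuming pytest-asyncio is installed (the @pytest.mark.asyncio marker comes from that plugin), the suite above runs directly, e.g. `pytest llama_stack/distribution/utils/tests/test_context.py -v`; the last test exercises the same thread-plus-fresh-event-loop boundary the server hits for streaming responses.
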
@@ -218,18 +218,10 @@ class ChatAgent(ShieldRunnerMixin):
         steps = []
         messages = await self.get_messages_from_turns(turns)
         if is_resume:
-            if isinstance(request.tool_responses[0], ToolResponseMessage):
-                tool_response_messages = request.tool_responses
-                tool_responses = [
-                    ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
-                    for x in request.tool_responses
-                ]
-            else:
-                tool_response_messages = [
-                    ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
-                    for x in request.tool_responses
-                ]
-                tool_responses = request.tool_responses
+            tool_response_messages = [
+                ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
+                for x in request.tool_responses
+            ]
             messages.extend(tool_response_messages)
             last_turn = turns[-1]
             last_turn_messages = self.turn_to_messages(last_turn)
@@ -252,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin):
                 step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
                 turn_id=request.turn_id,
                 tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
-                tool_responses=tool_responses,
+                tool_responses=request.tool_responses,
                 completed_at=now,
                 started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
             )
@@ -172,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
+        tool_responses: List[ToolResponse],
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
         request = AgentTurnResumeRequest(
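
Narrowing resume_turn to List[ToolResponse] removes the isinstance branching shown earlier: callers now send one canonical shape and ChatAgent derives the message form itself. A caller-side sketch with a stand-in dataclass (the real ToolResponse lives in llama_stack.apis and may carry more fields):

from dataclasses import dataclass
from typing import Any, List


@dataclass
class ToolResponse:  # stand-in for the llama_stack.apis type
    call_id: str
    tool_name: str
    content: Any


def build_resume_payload(outputs: List[tuple]) -> List[ToolResponse]:
    # One ToolResponse per pending tool call; ChatAgent derives the
    # ToolResponseMessage form internally, as in the hunk above.
    return [ToolResponse(call_id=cid, tool_name=name, content=out) for cid, name, out in outputs]


print(build_resume_payload([("call-1", "web_search", "top results")]))
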
@@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
     def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
         self.config = config
         self.datasetio_api = deps.get(Api.datasetio)
+        self.meter = None

         resource = Resource.create(
             {
@@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         return _GLOBAL_STORAGE["gauges"][name]

     def _log_metric(self, event: MetricEvent) -> None:
+        if self.meter is None:
+            return
         if isinstance(event.value, int):
             counter = self._get_or_create_counter(event.metric, event.unit)
             counter.add(event.value, attributes=event.attributes)
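
The guard added above turns metric logging into a no-op until a meter has been configured, rather than failing on the unset attribute. A stripped-down sketch of the pattern (TelemetryLike is illustrative, not the OpenTelemetry API):

class TelemetryLike:
    def __init__(self) -> None:
        self.meter = None  # wired up later, once instrumentation is configured

    def log_metric(self, name: str, value: int) -> None:
        if self.meter is None:
            return  # drop metrics emitted before setup completes
        self.meter.add(name, value)


t = TelemetryLike()
t.log_metric("prompt_tokens", 12)  # silently dropped; no AttributeError
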
@@ -4,8 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Tuple
+from typing import Dict, List, Tuple

+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import (
     BenchmarkInput,
@@ -15,21 +16,27 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
-from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
+from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
+    SQLiteVectorIOConfig,
+)
 from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
 from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
-from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
-from llama_stack.providers.utils.inference.model_registry import (
-    ProviderModelEntry,
-)
-from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.providers.remote.vector_io.pgvector.config import (
+    PGVectorVectorIOConfig,
+)
+from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
+from llama_stack.templates.template import (
+    DistributionTemplate,
+    RunConfigSettings,
+    get_model_registry,
+)


-def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
+def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
     # in this template, we allow each API key to be optional
     providers = [
         (
@@ -164,7 +171,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
             dataset_id="simpleqa",
             provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"),
             metadata={
                 "path": "llamastack/simpleqa",
                 "split": "train",
@@ -178,7 +185,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
             dataset_id="mmlu_cot",
             provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"),
             metadata={
                 "path": "llamastack/mmlu_cot",
                 "name": "all",
@@ -193,7 +200,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
             dataset_id="gpqa_cot",
             provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"),
             metadata={
                 "path": "llamastack/gpqa_0shot_cot",
                 "name": "gpqa_main",
@@ -208,7 +215,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
             dataset_id="math_500",
             provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/math_500"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"),
             metadata={
                 "path": "llamastack/math_500",
                 "split": "test",
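
The four hunks above replace raw {"uri": ...} dicts with the typed URL wrapper imported earlier, so a malformed dataset location fails at construction rather than deep inside the datasets provider. A sketch with a stand-in type (the real URL class is defined in llama_stack.apis.common.content_types):

from dataclasses import dataclass


@dataclass(frozen=True)
class URL:  # stand-in for llama_stack.apis.common.content_types.URL
    uri: str


# Before: url={"uri": "..."} is an untyped dict, typo-prone and unvalidated.
# After:  url=URL(uri="...") is constructed (and type-checked) up front.
url = URL(uri="https://huggingface.co/datasets/llamastack/simpleqa")
print(url.uri)
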
@@ -30,7 +30,9 @@ from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


-def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]:
+def get_model_registry(
+    available_models: Dict[str, List[ProviderModelEntry]],
+) -> List[ModelInput]:
     models = []
     for provider_id, entries in available_models.items():
         for entry in entries:
@@ -193,7 +195,7 @@ class DistributionTemplate(BaseModel):
             default_models.append(
                 DefaultModel(
                     model_id=model_entry.provider_model_id,
-                    doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "",
+                    doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
                 )
             )
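
get_inference_providers() now returns its models as Dict[str, List[ProviderModelEntry]], the shape get_model_registry flattens into ModelInput entries. A sketch of just that iteration, with a stand-in entry type carrying only the one field the template reads:

from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class ProviderModelEntry:  # stand-in; the real type has more fields
    provider_model_id: str


def flatten_registry(
    available_models: Dict[str, List[ProviderModelEntry]],
) -> List[Tuple[str, str]]:
    # One (provider_id, model_id) pair per entry, in provider order.
    return [
        (provider_id, entry.provider_model_id)
        for provider_id, entries in available_models.items()
        for entry in entries
    ]


print(flatten_registry({"groq": [ProviderModelEntry("llama-3.1-8b-instant")]}))
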