chore: more mypy fixes (#2029)

# What does this PR do?

Mainly tried to cover the entire llama_stack/apis directory; we only
have one left. Some excludes were just no-ops.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-05-06 18:52:31 +02:00 committed by GitHub
parent feb9eb8b0d
commit 1a529705da
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 581 additions and 166 deletions

View file

@ -4052,9 +4052,13 @@
"properties": { "properties": {
"type": { "type": {
"type": "string", "type": "string",
"enum": [
"json_schema",
"grammar"
],
"description": "Must be \"grammar\" to identify this format type",
"const": "grammar", "const": "grammar",
"default": "grammar", "default": "grammar"
"description": "Must be \"grammar\" to identify this format type"
}, },
"bnf": { "bnf": {
"type": "object", "type": "object",
@ -4178,9 +4182,13 @@
"properties": { "properties": {
"type": { "type": {
"type": "string", "type": "string",
"enum": [
"json_schema",
"grammar"
],
"description": "Must be \"json_schema\" to identify this format type",
"const": "json_schema", "const": "json_schema",
"default": "json_schema", "default": "json_schema"
"description": "Must be \"json_schema\" to identify this format type"
}, },
"json_schema": { "json_schema": {
"type": "object", "type": "object",
@ -5638,6 +5646,14 @@
}, },
"step_type": { "step_type": {
"type": "string", "type": "string",
"enum": [
"inference",
"tool_execution",
"shield_call",
"memory_retrieval"
],
"title": "StepType",
"description": "Type of the step in an agent turn.",
"const": "inference", "const": "inference",
"default": "inference" "default": "inference"
}, },
@ -5679,6 +5695,14 @@
}, },
"step_type": { "step_type": {
"type": "string", "type": "string",
"enum": [
"inference",
"tool_execution",
"shield_call",
"memory_retrieval"
],
"title": "StepType",
"description": "Type of the step in an agent turn.",
"const": "memory_retrieval", "const": "memory_retrieval",
"default": "memory_retrieval" "default": "memory_retrieval"
}, },
@ -5767,6 +5791,14 @@
}, },
"step_type": { "step_type": {
"type": "string", "type": "string",
"enum": [
"inference",
"tool_execution",
"shield_call",
"memory_retrieval"
],
"title": "StepType",
"description": "Type of the step in an agent turn.",
"const": "shield_call", "const": "shield_call",
"default": "shield_call" "default": "shield_call"
}, },
@ -5807,6 +5839,14 @@
}, },
"step_type": { "step_type": {
"type": "string", "type": "string",
"enum": [
"inference",
"tool_execution",
"shield_call",
"memory_retrieval"
],
"title": "StepType",
"description": "Type of the step in an agent turn.",
"const": "tool_execution", "const": "tool_execution",
"default": "tool_execution" "default": "tool_execution"
}, },
@ -6069,6 +6109,15 @@
"properties": { "properties": {
"event_type": { "event_type": {
"type": "string", "type": "string",
"enum": [
"step_start",
"step_complete",
"step_progress",
"turn_start",
"turn_complete",
"turn_awaiting_input"
],
"title": "AgentTurnResponseEventType",
"const": "step_complete", "const": "step_complete",
"default": "step_complete" "default": "step_complete"
}, },
@ -6126,6 +6175,15 @@
"properties": { "properties": {
"event_type": { "event_type": {
"type": "string", "type": "string",
"enum": [
"step_start",
"step_complete",
"step_progress",
"turn_start",
"turn_complete",
"turn_awaiting_input"
],
"title": "AgentTurnResponseEventType",
"const": "step_progress", "const": "step_progress",
"default": "step_progress" "default": "step_progress"
}, },
@ -6161,6 +6219,15 @@
"properties": { "properties": {
"event_type": { "event_type": {
"type": "string", "type": "string",
"enum": [
"step_start",
"step_complete",
"step_progress",
"turn_start",
"turn_complete",
"turn_awaiting_input"
],
"title": "AgentTurnResponseEventType",
"const": "step_start", "const": "step_start",
"default": "step_start" "default": "step_start"
}, },
@ -6231,6 +6298,15 @@
"properties": { "properties": {
"event_type": { "event_type": {
"type": "string", "type": "string",
"enum": [
"step_start",
"step_complete",
"step_progress",
"turn_start",
"turn_complete",
"turn_awaiting_input"
],
"title": "AgentTurnResponseEventType",
"const": "turn_awaiting_input", "const": "turn_awaiting_input",
"default": "turn_awaiting_input" "default": "turn_awaiting_input"
}, },
@ -6250,6 +6326,15 @@
"properties": { "properties": {
"event_type": { "event_type": {
"type": "string", "type": "string",
"enum": [
"step_start",
"step_complete",
"step_progress",
"turn_start",
"turn_complete",
"turn_awaiting_input"
],
"title": "AgentTurnResponseEventType",
"const": "turn_complete", "const": "turn_complete",
"default": "turn_complete" "default": "turn_complete"
}, },
@ -6269,6 +6354,15 @@
"properties": { "properties": {
"event_type": { "event_type": {
"type": "string", "type": "string",
"enum": [
"step_start",
"step_complete",
"step_progress",
"turn_start",
"turn_complete",
"turn_awaiting_input"
],
"title": "AgentTurnResponseEventType",
"const": "turn_start", "const": "turn_start",
"default": "turn_start" "default": "turn_start"
}, },
@ -6876,7 +6970,7 @@
"type": "object", "type": "object",
"properties": { "properties": {
"type": { "type": {
"type": "string", "$ref": "#/components/schemas/ScoringFnParamsType",
"const": "basic", "const": "basic",
"default": "basic" "default": "basic"
}, },
@ -6889,7 +6983,8 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"type" "type",
"aggregation_functions"
], ],
"title": "BasicScoringFnParams" "title": "BasicScoringFnParams"
}, },
@ -6941,7 +7036,7 @@
"type": "object", "type": "object",
"properties": { "properties": {
"type": { "type": {
"type": "string", "$ref": "#/components/schemas/ScoringFnParamsType",
"const": "llm_as_judge", "const": "llm_as_judge",
"default": "llm_as_judge" "default": "llm_as_judge"
}, },
@ -6967,7 +7062,9 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"type", "type",
"judge_model" "judge_model",
"judge_score_regexes",
"aggregation_functions"
], ],
"title": "LLMAsJudgeScoringFnParams" "title": "LLMAsJudgeScoringFnParams"
}, },
@ -7005,7 +7102,7 @@
"type": "object", "type": "object",
"properties": { "properties": {
"type": { "type": {
"type": "string", "$ref": "#/components/schemas/ScoringFnParamsType",
"const": "regex_parser", "const": "regex_parser",
"default": "regex_parser" "default": "regex_parser"
}, },
@ -7024,7 +7121,9 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"type" "type",
"parsing_regexes",
"aggregation_functions"
], ],
"title": "RegexParserScoringFnParams" "title": "RegexParserScoringFnParams"
}, },
@ -7049,6 +7148,15 @@
} }
} }
}, },
"ScoringFnParamsType": {
"type": "string",
"enum": [
"llm_as_judge",
"regex_parser",
"basic"
],
"title": "ScoringFnParamsType"
},
"EvaluateRowsRequest": { "EvaluateRowsRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -7317,6 +7425,17 @@
}, },
"type": { "type": {
"type": "string", "type": "string",
"enum": [
"model",
"shield",
"vector_db",
"dataset",
"scoring_function",
"benchmark",
"tool",
"tool_group"
],
"title": "ResourceType",
"const": "benchmark", "const": "benchmark",
"default": "benchmark" "default": "benchmark"
}, },
@ -7358,7 +7477,6 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"identifier", "identifier",
"provider_resource_id",
"provider_id", "provider_id",
"type", "type",
"dataset_id", "dataset_id",
@ -7398,6 +7516,17 @@
}, },
"type": { "type": {
"type": "string", "type": "string",
"enum": [
"model",
"shield",
"vector_db",
"dataset",
"scoring_function",
"benchmark",
"tool",
"tool_group"
],
"title": "ResourceType",
"const": "dataset", "const": "dataset",
"default": "dataset" "default": "dataset"
}, },
@ -7443,7 +7572,6 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"identifier", "identifier",
"provider_resource_id",
"provider_id", "provider_id",
"type", "type",
"purpose", "purpose",
@ -7573,6 +7701,17 @@
}, },
"type": { "type": {
"type": "string", "type": "string",
"enum": [
"model",
"shield",
"vector_db",
"dataset",
"scoring_function",
"benchmark",
"tool",
"tool_group"
],
"title": "ResourceType",
"const": "model", "const": "model",
"default": "model" "default": "model"
}, },
@ -7609,7 +7748,6 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"identifier", "identifier",
"provider_resource_id",
"provider_id", "provider_id",
"type", "type",
"metadata", "metadata",
@ -7808,6 +7946,17 @@
}, },
"type": { "type": {
"type": "string", "type": "string",
"enum": [
"model",
"shield",
"vector_db",
"dataset",
"scoring_function",
"benchmark",
"tool",
"tool_group"
],
"title": "ResourceType",
"const": "scoring_function", "const": "scoring_function",
"default": "scoring_function" "default": "scoring_function"
}, },
@ -7849,7 +7998,6 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"identifier", "identifier",
"provider_resource_id",
"provider_id", "provider_id",
"type", "type",
"metadata", "metadata",
@ -7901,6 +8049,17 @@
}, },
"type": { "type": {
"type": "string", "type": "string",
"enum": [
"model",
"shield",
"vector_db",
"dataset",
"scoring_function",
"benchmark",
"tool",
"tool_group"
],
"title": "ResourceType",
"const": "shield", "const": "shield",
"default": "shield" "default": "shield"
}, },
@ -7933,7 +8092,6 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"identifier", "identifier",
"provider_resource_id",
"provider_id", "provider_id",
"type" "type"
], ],
@ -8113,6 +8271,17 @@
}, },
"type": { "type": {
"type": "string", "type": "string",
"enum": [
"model",
"shield",
"vector_db",
"dataset",
"scoring_function",
"benchmark",
"tool",
"tool_group"
],
"title": "ResourceType",
"const": "tool", "const": "tool",
"default": "tool" "default": "tool"
}, },
@ -8160,7 +8329,6 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"identifier", "identifier",
"provider_resource_id",
"provider_id", "provider_id",
"type", "type",
"toolgroup_id", "toolgroup_id",
@ -8193,6 +8361,17 @@
}, },
"type": { "type": {
"type": "string", "type": "string",
"enum": [
"model",
"shield",
"vector_db",
"dataset",
"scoring_function",
"benchmark",
"tool",
"tool_group"
],
"title": "ResourceType",
"const": "tool_group", "const": "tool_group",
"default": "tool_group" "default": "tool_group"
}, },
@ -8228,7 +8407,6 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"identifier", "identifier",
"provider_resource_id",
"provider_id", "provider_id",
"type" "type"
], ],
@ -8395,6 +8573,17 @@
}, },
"type": { "type": {
"type": "string", "type": "string",
"enum": [
"model",
"shield",
"vector_db",
"dataset",
"scoring_function",
"benchmark",
"tool",
"tool_group"
],
"title": "ResourceType",
"const": "vector_db", "const": "vector_db",
"default": "vector_db" "default": "vector_db"
}, },
@ -8408,7 +8597,6 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"identifier", "identifier",
"provider_resource_id",
"provider_id", "provider_id",
"type", "type",
"embedding_model", "embedding_model",
@ -9110,6 +9298,15 @@
} }
} }
}, },
"EventType": {
"type": "string",
"enum": [
"unstructured_log",
"structured_log",
"metric"
],
"title": "EventType"
},
"LogSeverity": { "LogSeverity": {
"type": "string", "type": "string",
"enum": [ "enum": [
@ -9158,7 +9355,7 @@
} }
}, },
"type": { "type": {
"type": "string", "$ref": "#/components/schemas/EventType",
"const": "metric", "const": "metric",
"default": "metric" "default": "metric"
}, },
@ -9195,7 +9392,7 @@
"type": "object", "type": "object",
"properties": { "properties": {
"type": { "type": {
"type": "string", "$ref": "#/components/schemas/StructuredLogType",
"const": "span_end", "const": "span_end",
"default": "span_end" "default": "span_end"
}, },
@ -9214,7 +9411,7 @@
"type": "object", "type": "object",
"properties": { "properties": {
"type": { "type": {
"type": "string", "$ref": "#/components/schemas/StructuredLogType",
"const": "span_start", "const": "span_start",
"default": "span_start" "default": "span_start"
}, },
@ -9268,7 +9465,7 @@
} }
}, },
"type": { "type": {
"type": "string", "$ref": "#/components/schemas/EventType",
"const": "structured_log", "const": "structured_log",
"default": "structured_log" "default": "structured_log"
}, },
@ -9303,6 +9500,14 @@
} }
} }
}, },
"StructuredLogType": {
"type": "string",
"enum": [
"span_start",
"span_end"
],
"title": "StructuredLogType"
},
"UnstructuredLogEvent": { "UnstructuredLogEvent": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9339,7 +9544,7 @@
} }
}, },
"type": { "type": {
"type": "string", "$ref": "#/components/schemas/EventType",
"const": "unstructured_log", "const": "unstructured_log",
"default": "unstructured_log" "default": "unstructured_log"
}, },

View file

@ -2812,10 +2812,13 @@ components:
properties: properties:
type: type:
type: string type: string
const: grammar enum:
default: grammar - json_schema
- grammar
description: >- description: >-
Must be "grammar" to identify this format type Must be "grammar" to identify this format type
const: grammar
default: grammar
bnf: bnf:
type: object type: object
additionalProperties: additionalProperties:
@ -2897,10 +2900,13 @@ components:
properties: properties:
type: type:
type: string type: string
const: json_schema enum:
default: json_schema - json_schema
- grammar
description: >- description: >-
Must be "json_schema" to identify this format type Must be "json_schema" to identify this format type
const: json_schema
default: json_schema
json_schema: json_schema:
type: object type: object
additionalProperties: additionalProperties:
@ -3959,6 +3965,13 @@ components:
description: The time the step completed. description: The time the step completed.
step_type: step_type:
type: string type: string
enum:
- inference
- tool_execution
- shield_call
- memory_retrieval
title: StepType
description: Type of the step in an agent turn.
const: inference const: inference
default: inference default: inference
model_response: model_response:
@ -3991,6 +4004,13 @@ components:
description: The time the step completed. description: The time the step completed.
step_type: step_type:
type: string type: string
enum:
- inference
- tool_execution
- shield_call
- memory_retrieval
title: StepType
description: Type of the step in an agent turn.
const: memory_retrieval const: memory_retrieval
default: memory_retrieval default: memory_retrieval
vector_db_ids: vector_db_ids:
@ -4052,6 +4072,13 @@ components:
description: The time the step completed. description: The time the step completed.
step_type: step_type:
type: string type: string
enum:
- inference
- tool_execution
- shield_call
- memory_retrieval
title: StepType
description: Type of the step in an agent turn.
const: shield_call const: shield_call
default: shield_call default: shield_call
violation: violation:
@ -4083,6 +4110,13 @@ components:
description: The time the step completed. description: The time the step completed.
step_type: step_type:
type: string type: string
enum:
- inference
- tool_execution
- shield_call
- memory_retrieval
title: StepType
description: Type of the step in an agent turn.
const: tool_execution const: tool_execution
default: tool_execution default: tool_execution
tool_calls: tool_calls:
@ -4245,6 +4279,14 @@ components:
properties: properties:
event_type: event_type:
type: string type: string
enum:
- step_start
- step_complete
- step_progress
- turn_start
- turn_complete
- turn_awaiting_input
title: AgentTurnResponseEventType
const: step_complete const: step_complete
default: step_complete default: step_complete
step_type: step_type:
@ -4283,6 +4325,14 @@ components:
properties: properties:
event_type: event_type:
type: string type: string
enum:
- step_start
- step_complete
- step_progress
- turn_start
- turn_complete
- turn_awaiting_input
title: AgentTurnResponseEventType
const: step_progress const: step_progress
default: step_progress default: step_progress
step_type: step_type:
@ -4310,6 +4360,14 @@ components:
properties: properties:
event_type: event_type:
type: string type: string
enum:
- step_start
- step_complete
- step_progress
- turn_start
- turn_complete
- turn_awaiting_input
title: AgentTurnResponseEventType
const: step_start const: step_start
default: step_start default: step_start
step_type: step_type:
@ -4354,6 +4412,14 @@ components:
properties: properties:
event_type: event_type:
type: string type: string
enum:
- step_start
- step_complete
- step_progress
- turn_start
- turn_complete
- turn_awaiting_input
title: AgentTurnResponseEventType
const: turn_awaiting_input const: turn_awaiting_input
default: turn_awaiting_input default: turn_awaiting_input
turn: turn:
@ -4369,6 +4435,14 @@ components:
properties: properties:
event_type: event_type:
type: string type: string
enum:
- step_start
- step_complete
- step_progress
- turn_start
- turn_complete
- turn_awaiting_input
title: AgentTurnResponseEventType
const: turn_complete const: turn_complete
default: turn_complete default: turn_complete
turn: turn:
@ -4383,6 +4457,14 @@ components:
properties: properties:
event_type: event_type:
type: string type: string
enum:
- step_start
- step_complete
- step_progress
- turn_start
- turn_complete
- turn_awaiting_input
title: AgentTurnResponseEventType
const: turn_start const: turn_start
default: turn_start default: turn_start
turn_id: turn_id:
@ -4825,7 +4907,7 @@ components:
type: object type: object
properties: properties:
type: type:
type: string $ref: '#/components/schemas/ScoringFnParamsType'
const: basic const: basic
default: basic default: basic
aggregation_functions: aggregation_functions:
@ -4835,6 +4917,7 @@ components:
additionalProperties: false additionalProperties: false
required: required:
- type - type
- aggregation_functions
title: BasicScoringFnParams title: BasicScoringFnParams
BenchmarkConfig: BenchmarkConfig:
type: object type: object
@ -4874,7 +4957,7 @@ components:
type: object type: object
properties: properties:
type: type:
type: string $ref: '#/components/schemas/ScoringFnParamsType'
const: llm_as_judge const: llm_as_judge
default: llm_as_judge default: llm_as_judge
judge_model: judge_model:
@ -4893,6 +4976,8 @@ components:
required: required:
- type - type
- judge_model - judge_model
- judge_score_regexes
- aggregation_functions
title: LLMAsJudgeScoringFnParams title: LLMAsJudgeScoringFnParams
ModelCandidate: ModelCandidate:
type: object type: object
@ -4923,7 +5008,7 @@ components:
type: object type: object
properties: properties:
type: type:
type: string $ref: '#/components/schemas/ScoringFnParamsType'
const: regex_parser const: regex_parser
default: regex_parser default: regex_parser
parsing_regexes: parsing_regexes:
@ -4937,6 +5022,8 @@ components:
additionalProperties: false additionalProperties: false
required: required:
- type - type
- parsing_regexes
- aggregation_functions
title: RegexParserScoringFnParams title: RegexParserScoringFnParams
ScoringFnParams: ScoringFnParams:
oneOf: oneOf:
@ -4949,6 +5036,13 @@ components:
llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams'
regex_parser: '#/components/schemas/RegexParserScoringFnParams' regex_parser: '#/components/schemas/RegexParserScoringFnParams'
basic: '#/components/schemas/BasicScoringFnParams' basic: '#/components/schemas/BasicScoringFnParams'
ScoringFnParamsType:
type: string
enum:
- llm_as_judge
- regex_parser
- basic
title: ScoringFnParamsType
EvaluateRowsRequest: EvaluateRowsRequest:
type: object type: object
properties: properties:
@ -5111,6 +5205,16 @@ components:
type: string type: string
type: type:
type: string type: string
enum:
- model
- shield
- vector_db
- dataset
- scoring_function
- benchmark
- tool
- tool_group
title: ResourceType
const: benchmark const: benchmark
default: benchmark default: benchmark
dataset_id: dataset_id:
@ -5132,7 +5236,6 @@ components:
additionalProperties: false additionalProperties: false
required: required:
- identifier - identifier
- provider_resource_id
- provider_id - provider_id
- type - type
- dataset_id - dataset_id
@ -5159,6 +5262,16 @@ components:
type: string type: string
type: type:
type: string type: string
enum:
- model
- shield
- vector_db
- dataset
- scoring_function
- benchmark
- tool
- tool_group
title: ResourceType
const: dataset const: dataset
default: dataset default: dataset
purpose: purpose:
@ -5185,7 +5298,6 @@ components:
additionalProperties: false additionalProperties: false
required: required:
- identifier - identifier
- provider_resource_id
- provider_id - provider_id
- type - type
- purpose - purpose
@ -5284,6 +5396,16 @@ components:
type: string type: string
type: type:
type: string type: string
enum:
- model
- shield
- vector_db
- dataset
- scoring_function
- benchmark
- tool
- tool_group
title: ResourceType
const: model const: model
default: model default: model
metadata: metadata:
@ -5302,7 +5424,6 @@ components:
additionalProperties: false additionalProperties: false
required: required:
- identifier - identifier
- provider_resource_id
- provider_id - provider_id
- type - type
- metadata - metadata
@ -5438,6 +5559,16 @@ components:
type: string type: string
type: type:
type: string type: string
enum:
- model
- shield
- vector_db
- dataset
- scoring_function
- benchmark
- tool
- tool_group
title: ResourceType
const: scoring_function const: scoring_function
default: scoring_function default: scoring_function
description: description:
@ -5459,7 +5590,6 @@ components:
additionalProperties: false additionalProperties: false
required: required:
- identifier - identifier
- provider_resource_id
- provider_id - provider_id
- type - type
- metadata - metadata
@ -5498,6 +5628,16 @@ components:
type: string type: string
type: type:
type: string type: string
enum:
- model
- shield
- vector_db
- dataset
- scoring_function
- benchmark
- tool
- tool_group
title: ResourceType
const: shield const: shield
default: shield default: shield
params: params:
@ -5513,7 +5653,6 @@ components:
additionalProperties: false additionalProperties: false
required: required:
- identifier - identifier
- provider_resource_id
- provider_id - provider_id
- type - type
title: Shield title: Shield
@ -5628,6 +5767,16 @@ components:
type: string type: string
type: type:
type: string type: string
enum:
- model
- shield
- vector_db
- dataset
- scoring_function
- benchmark
- tool
- tool_group
title: ResourceType
const: tool const: tool
default: tool default: tool
toolgroup_id: toolgroup_id:
@ -5653,7 +5802,6 @@ components:
additionalProperties: false additionalProperties: false
required: required:
- identifier - identifier
- provider_resource_id
- provider_id - provider_id
- type - type
- toolgroup_id - toolgroup_id
@ -5679,6 +5827,16 @@ components:
type: string type: string
type: type:
type: string type: string
enum:
- model
- shield
- vector_db
- dataset
- scoring_function
- benchmark
- tool
- tool_group
title: ResourceType
const: tool_group const: tool_group
default: tool_group default: tool_group
mcp_endpoint: mcp_endpoint:
@ -5696,7 +5854,6 @@ components:
additionalProperties: false additionalProperties: false
required: required:
- identifier - identifier
- provider_resource_id
- provider_id - provider_id
- type - type
title: ToolGroup title: ToolGroup
@ -5810,6 +5967,16 @@ components:
type: string type: string
type: type:
type: string type: string
enum:
- model
- shield
- vector_db
- dataset
- scoring_function
- benchmark
- tool
- tool_group
title: ResourceType
const: vector_db const: vector_db
default: vector_db default: vector_db
embedding_model: embedding_model:
@ -5819,7 +5986,6 @@ components:
additionalProperties: false additionalProperties: false
required: required:
- identifier - identifier
- provider_resource_id
- provider_id - provider_id
- type - type
- embedding_model - embedding_model
@ -6259,6 +6425,13 @@ components:
unstructured_log: '#/components/schemas/UnstructuredLogEvent' unstructured_log: '#/components/schemas/UnstructuredLogEvent'
metric: '#/components/schemas/MetricEvent' metric: '#/components/schemas/MetricEvent'
structured_log: '#/components/schemas/StructuredLogEvent' structured_log: '#/components/schemas/StructuredLogEvent'
EventType:
type: string
enum:
- unstructured_log
- structured_log
- metric
title: EventType
LogSeverity: LogSeverity:
type: string type: string
enum: enum:
@ -6289,7 +6462,7 @@ components:
- type: boolean - type: boolean
- type: 'null' - type: 'null'
type: type:
type: string $ref: '#/components/schemas/EventType'
const: metric const: metric
default: metric default: metric
metric: metric:
@ -6314,7 +6487,7 @@ components:
type: object type: object
properties: properties:
type: type:
type: string $ref: '#/components/schemas/StructuredLogType'
const: span_end const: span_end
default: span_end default: span_end
status: status:
@ -6328,7 +6501,7 @@ components:
type: object type: object
properties: properties:
type: type:
type: string $ref: '#/components/schemas/StructuredLogType'
const: span_start const: span_start
default: span_start default: span_start
name: name:
@ -6360,7 +6533,7 @@ components:
- type: boolean - type: boolean
- type: 'null' - type: 'null'
type: type:
type: string $ref: '#/components/schemas/EventType'
const: structured_log const: structured_log
default: structured_log default: structured_log
payload: payload:
@ -6382,6 +6555,12 @@ components:
mapping: mapping:
span_start: '#/components/schemas/SpanStartPayload' span_start: '#/components/schemas/SpanStartPayload'
span_end: '#/components/schemas/SpanEndPayload' span_end: '#/components/schemas/SpanEndPayload'
StructuredLogType:
type: string
enum:
- span_start
- span_end
title: StructuredLogType
UnstructuredLogEvent: UnstructuredLogEvent:
type: object type: object
properties: properties:
@ -6402,7 +6581,7 @@ components:
- type: boolean - type: boolean
- type: 'null' - type: 'null'
type: type:
type: string $ref: '#/components/schemas/EventType'
const: unstructured_log const: unstructured_log
default: unstructured_log default: unstructured_log
message: message:

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import sys
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@ -35,6 +36,14 @@ from .openai_responses import (
OpenAIResponseObjectStream, OpenAIResponseObjectStream,
) )
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
class Attachment(BaseModel): class Attachment(BaseModel):
"""An attachment to an agent turn. """An attachment to an agent turn.
@ -73,7 +82,7 @@ class StepCommon(BaseModel):
completed_at: datetime | None = None completed_at: datetime | None = None
class StepType(Enum): class StepType(StrEnum):
"""Type of the step in an agent turn. """Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM. :cvar inference: The step is an inference step that calls an LLM.
@ -97,7 +106,7 @@ class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage model_response: CompletionMessage
@ -109,7 +118,7 @@ class ToolExecutionStep(StepCommon):
:param tool_responses: The tool responses from the tool calls. :param tool_responses: The tool responses from the tool calls.
""" """
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall] tool_calls: list[ToolCall]
tool_responses: list[ToolResponse] tool_responses: list[ToolResponse]
@ -121,7 +130,7 @@ class ShieldCallStep(StepCommon):
:param violation: The violation from the shield call. :param violation: The violation from the shield call.
""" """
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value step_type: Literal[StepType.shield_call] = StepType.shield_call
violation: SafetyViolation | None violation: SafetyViolation | None
@ -133,7 +142,7 @@ class MemoryRetrievalStep(StepCommon):
:param inserted_context: The context retrieved from the vector databases. :param inserted_context: The context retrieved from the vector databases.
""" """
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]? # TODO: should this be List[str]?
vector_db_ids: str vector_db_ids: str
inserted_context: InterleavedContent inserted_context: InterleavedContent
@ -154,7 +163,7 @@ class Turn(BaseModel):
input_messages: list[UserMessage | ToolResponseMessage] input_messages: list[UserMessage | ToolResponseMessage]
steps: list[Step] steps: list[Step]
output_message: CompletionMessage output_message: CompletionMessage
output_attachments: list[Attachment] | None = Field(default_factory=list) output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime started_at: datetime
completed_at: datetime | None = None completed_at: datetime | None = None
@ -182,10 +191,10 @@ register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel): class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: list[str] | None = Field(default_factory=list) input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=list) output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list) toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=list) client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead") tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead") tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None) tool_config: ToolConfig | None = Field(default=None)
@ -246,7 +255,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None instructions: str | None = None
class AgentTurnResponseEventType(Enum): class AgentTurnResponseEventType(StrEnum):
step_start = "step_start" step_start = "step_start"
step_complete = "step_complete" step_complete = "step_complete"
step_progress = "step_progress" step_progress = "step_progress"
@ -258,15 +267,15 @@ class AgentTurnResponseEventType(Enum):
@json_schema_type @json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel): class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
step_type: StepType step_type: StepType
step_id: str step_id: str
metadata: dict[str, Any] | None = Field(default_factory=dict) metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
@json_schema_type @json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel): class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
step_type: StepType step_type: StepType
step_id: str step_id: str
step_details: Step step_details: Step
@ -276,7 +285,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
class AgentTurnResponseStepProgressPayload(BaseModel): class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
step_type: StepType step_type: StepType
step_id: str step_id: str
@ -285,21 +294,19 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
@json_schema_type @json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel): class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
turn_id: str turn_id: str
@json_schema_type @json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel): class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
turn: Turn turn: Turn
@json_schema_type @json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel): class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = ( event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
AgentTurnResponseEventType.turn_awaiting_input.value
)
turn: Turn turn: Turn
@ -341,7 +348,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
messages: list[UserMessage | ToolResponseMessage] messages: list[UserMessage | ToolResponseMessage]
documents: list[Document] | None = None documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = None toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
stream: bool | None = False stream: bool | None = False
tool_config: ToolConfig | None = None tool_config: ToolConfig | None = None

View file

@ -22,14 +22,14 @@ class CommonBenchmarkFields(BaseModel):
@json_schema_type @json_schema_type
class Benchmark(CommonBenchmarkFields, Resource): class Benchmark(CommonBenchmarkFields, Resource):
type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value type: Literal[ResourceType.benchmark] = ResourceType.benchmark
@property @property
def benchmark_id(self) -> str: def benchmark_id(self) -> str:
return self.identifier return self.identifier
@property @property
def provider_benchmark_id(self) -> str: def provider_benchmark_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id

View file

@ -28,7 +28,7 @@ class _URLOrData(BaseModel):
url: URL | None = None url: URL | None = None
# data is a base64 encoded string, hint with contentEncoding=base64 # data is a base64 encoded string, hint with contentEncoding=base64
data: str | None = Field(contentEncoding="base64", default=None) data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod

View file

@ -106,14 +106,14 @@ class CommonDatasetFields(BaseModel):
@json_schema_type @json_schema_type
class Dataset(CommonDatasetFields, Resource): class Dataset(CommonDatasetFields, Resource):
type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value type: Literal[ResourceType.dataset] = ResourceType.dataset
@property @property
def dataset_id(self) -> str: def dataset_id(self) -> str:
return self.identifier return self.identifier
@property @property
def provider_dataset_id(self) -> str: def provider_dataset_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import sys
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from enum import Enum from enum import Enum
from typing import ( from typing import (
@ -35,6 +36,16 @@ register_schema(ToolCall)
register_schema(ToolParamDefinition) register_schema(ToolParamDefinition)
register_schema(ToolDefinition) register_schema(ToolDefinition)
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
@json_schema_type @json_schema_type
class GreedySamplingStrategy(BaseModel): class GreedySamplingStrategy(BaseModel):
@ -187,7 +198,7 @@ class CompletionMessage(BaseModel):
role: Literal["assistant"] = "assistant" role: Literal["assistant"] = "assistant"
content: InterleavedContent content: InterleavedContent
stop_reason: StopReason stop_reason: StopReason
tool_calls: list[ToolCall] | None = Field(default_factory=list) tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
Message = Annotated[ Message = Annotated[
@ -267,7 +278,7 @@ class ChatCompletionResponseEvent(BaseModel):
stop_reason: StopReason | None = None stop_reason: StopReason | None = None
class ResponseFormatType(Enum): class ResponseFormatType(StrEnum):
"""Types of formats for structured (guided) decoding. """Types of formats for structured (guided) decoding.
:cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model. :cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model.
@ -286,7 +297,7 @@ class JsonSchemaResponseFormat(BaseModel):
:param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model. :param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
""" """
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value type: Literal[ResponseFormatType.json_schema] = ResponseFormatType.json_schema
json_schema: dict[str, Any] json_schema: dict[str, Any]
@ -298,7 +309,7 @@ class GrammarResponseFormat(BaseModel):
:param bnf: The BNF grammar specification the response should conform to :param bnf: The BNF grammar specification the response should conform to
""" """
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value type: Literal[ResponseFormatType.grammar] = ResponseFormatType.grammar
bnf: dict[str, Any] bnf: dict[str, Any]
@ -394,7 +405,7 @@ class ChatCompletionRequest(BaseModel):
messages: list[Message] messages: list[Message]
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
tools: list[ToolDefinition] | None = Field(default_factory=list) tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
tool_config: ToolConfig | None = Field(default_factory=ToolConfig) tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
response_format: ResponseFormat | None = None response_format: ResponseFormat | None = None
@ -567,14 +578,14 @@ class OpenAIResponseFormatText(BaseModel):
@json_schema_type @json_schema_type
class OpenAIJSONSchema(TypedDict, total=False): class OpenAIJSONSchema(TypedDict, total=False):
name: str name: str
description: str | None = None description: str | None
strict: bool | None = None strict: bool | None
# Pydantic BaseModel cannot be used with a schema param, since it already # Pydantic BaseModel cannot be used with a schema param, since it already
# has one. And, we don't want to alias here because then have to handle # has one. And, we don't want to alias here because then have to handle
# that alias when converting to OpenAI params. So, to support schema, # that alias when converting to OpenAI params. So, to support schema,
# we use a TypedDict. # we use a TypedDict.
schema: dict[str, Any] | None = None schema: dict[str, Any] | None
@json_schema_type @json_schema_type

View file

@ -29,14 +29,14 @@ class ModelType(str, Enum):
@json_schema_type @json_schema_type
class Model(CommonModelFields, Resource): class Model(CommonModelFields, Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value type: Literal[ResourceType.model] = ResourceType.model
@property @property
def model_id(self) -> str: def model_id(self) -> str:
return self.identifier return self.identifier
@property @property
def provider_model_id(self) -> str: def provider_model_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())

View file

@ -4,12 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import sys
from enum import Enum from enum import Enum
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class ResourceType(Enum): class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
class ResourceType(StrEnum):
model = "model" model = "model"
shield = "shield" shield = "shield"
vector_db = "vector_db" vector_db = "vector_db"
@ -25,9 +36,9 @@ class Resource(BaseModel):
identifier: str = Field(description="Unique identifier for this resource in llama stack") identifier: str = Field(description="Unique identifier for this resource in llama stack")
provider_resource_id: str = Field( provider_resource_id: str | None = Field(
description="Unique identifier for this resource in the provider",
default=None, default=None,
description="Unique identifier for this resource in the provider",
) )
provider_id: str = Field(description="ID of the provider that owns this resource") provider_id: str = Field(description="ID of the provider that owns this resource")

View file

@ -53,5 +53,5 @@ class Safety(Protocol):
self, self,
shield_id: str, shield_id: str,
messages: list[Message], messages: list[Message],
params: dict[str, Any] = None, params: dict[str, Any],
) -> RunShieldResponse: ... ) -> RunShieldResponse: ...

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# TODO: use enum.StrEnum when we drop support for python 3.10
import sys
from enum import Enum from enum import Enum
from typing import ( from typing import (
Annotated, Annotated,
@ -19,18 +21,27 @@ from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
# Perhaps more structure can be imposed on these functions. Maybe they could be associated # Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up? # with standard metrics so they can be rolled up?
@json_schema_type @json_schema_type
class ScoringFnParamsType(Enum): class ScoringFnParamsType(StrEnum):
llm_as_judge = "llm_as_judge" llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser" regex_parser = "regex_parser"
basic = "basic" basic = "basic"
@json_schema_type @json_schema_type
class AggregationFunctionType(Enum): class AggregationFunctionType(StrEnum):
average = "average" average = "average"
weighted_average = "weighted_average" weighted_average = "weighted_average"
median = "median" median = "median"
@ -40,36 +51,36 @@ class AggregationFunctionType(Enum):
@json_schema_type @json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel): class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
judge_model: str judge_model: str
prompt_template: str | None = None prompt_template: str | None = None
judge_score_regexes: list[str] | None = Field( judge_score_regexes: list[str] = Field(
description="Regexes to extract the answer from generated response", description="Regexes to extract the answer from generated response",
default_factory=list, default_factory=lambda: [],
) )
aggregation_functions: list[AggregationFunctionType] | None = Field( aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row", description="Aggregation functions to apply to the scores of each row",
default_factory=list, default_factory=lambda: [],
) )
@json_schema_type @json_schema_type
class RegexParserScoringFnParams(BaseModel): class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
parsing_regexes: list[str] | None = Field( parsing_regexes: list[str] = Field(
description="Regex to extract the answer from generated response", description="Regex to extract the answer from generated response",
default_factory=list, default_factory=lambda: [],
) )
aggregation_functions: list[AggregationFunctionType] | None = Field( aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row", description="Aggregation functions to apply to the scores of each row",
default_factory=list, default_factory=lambda: [],
) )
@json_schema_type @json_schema_type
class BasicScoringFnParams(BaseModel): class BasicScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
aggregation_functions: list[AggregationFunctionType] | None = Field( aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row", description="Aggregation functions to apply to the scores of each row",
default_factory=list, default_factory=list,
) )
@ -99,14 +110,14 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type @json_schema_type
class ScoringFn(CommonScoringFnFields, Resource): class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function
@property @property
def scoring_fn_id(self) -> str: def scoring_fn_id(self) -> str:
return self.identifier return self.identifier
@property @property
def provider_scoring_fn_id(self) -> str: def provider_scoring_fn_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id

View file

@ -21,14 +21,14 @@ class CommonShieldFields(BaseModel):
class Shield(CommonShieldFields, Resource): class Shield(CommonShieldFields, Resource):
"""A safety shield resource that can be used to check content""" """A safety shield resource that can be used to check content"""
type: Literal[ResourceType.shield.value] = ResourceType.shield.value type: Literal[ResourceType.shield] = ResourceType.shield
@property @property
def shield_id(self) -> str: def shield_id(self) -> str:
return self.identifier return self.identifier
@property @property
def provider_shield_id(self) -> str: def provider_shield_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id

View file

@ -37,7 +37,7 @@ class Span(BaseModel):
name: str name: str
start_time: datetime start_time: datetime
end_time: datetime | None = None end_time: datetime | None = None
attributes: dict[str, Any] | None = Field(default_factory=dict) attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
def set_attribute(self, key: str, value: Any): def set_attribute(self, key: str, value: Any):
if self.attributes is None: if self.attributes is None:
@ -74,19 +74,19 @@ class EventCommon(BaseModel):
trace_id: str trace_id: str
span_id: str span_id: str
timestamp: datetime timestamp: datetime
attributes: dict[str, Primitive] | None = Field(default_factory=dict) attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
@json_schema_type @json_schema_type
class UnstructuredLogEvent(EventCommon): class UnstructuredLogEvent(EventCommon):
type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
message: str message: str
severity: LogSeverity severity: LogSeverity
@json_schema_type @json_schema_type
class MetricEvent(EventCommon): class MetricEvent(EventCommon):
type: Literal[EventType.METRIC.value] = EventType.METRIC.value type: Literal[EventType.METRIC] = EventType.METRIC
metric: str # this would be an enum metric: str # this would be an enum
value: int | float value: int | float
unit: str unit: str
@ -131,14 +131,14 @@ class StructuredLogType(Enum):
@json_schema_type @json_schema_type
class SpanStartPayload(BaseModel): class SpanStartPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
name: str name: str
parent_span_id: str | None = None parent_span_id: str | None = None
@json_schema_type @json_schema_type
class SpanEndPayload(BaseModel): class SpanEndPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
status: SpanStatus status: SpanStatus
@ -151,7 +151,7 @@ register_schema(StructuredLogPayload, name="StructuredLogPayload")
@json_schema_type @json_schema_type
class StructuredLogEvent(EventCommon): class StructuredLogEvent(EventCommon):
type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
payload: StructuredLogPayload payload: StructuredLogPayload

View file

@ -36,7 +36,7 @@ class ToolHost(Enum):
@json_schema_type @json_schema_type
class Tool(Resource): class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value type: Literal[ResourceType.tool] = ResourceType.tool
toolgroup_id: str toolgroup_id: str
tool_host: ToolHost tool_host: ToolHost
description: str description: str
@ -62,7 +62,7 @@ class ToolGroupInput(BaseModel):
@json_schema_type @json_schema_type
class ToolGroup(Resource): class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value type: Literal[ResourceType.tool_group] = ResourceType.tool_group
mcp_endpoint: URL | None = None mcp_endpoint: URL | None = None
args: dict[str, Any] | None = None args: dict[str, Any] | None = None

View file

@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type @json_schema_type
class VectorDB(Resource): class VectorDB(Resource):
type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value type: Literal[ResourceType.vector_db] = ResourceType.vector_db
embedding_model: str embedding_model: str
embedding_dimension: int embedding_dimension: int
@ -25,7 +25,7 @@ class VectorDB(Resource):
return self.identifier return self.identifier
@property @property
def provider_vector_db_id(self) -> str: def provider_vector_db_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id

View file

@ -38,7 +38,10 @@ class LlamaCLIParser:
print_subcommand_description(self.parser, subparsers) print_subcommand_description(self.parser, subparsers)
def parse_args(self) -> argparse.Namespace: def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args() args = self.parser.parse_args()
if not isinstance(args, argparse.Namespace):
raise TypeError(f"Expected argparse.Namespace, got {type(args)}")
return args
def run(self, args: argparse.Namespace) -> None: def run(self, args: argparse.Namespace) -> None:
args.func(args) args.func(args)

View file

@ -46,7 +46,7 @@ class StackListProviders(Subcommand):
else: else:
providers = [(k.value, prov) for k, prov in all_providers.items()] providers = [(k.value, prov) for k, prov in all_providers.items()]
providers = [p for api, p in providers if api in self.providable_apis] providers = [(api, p) for api, p in providers if api in self.providable_apis]
# eventually, this should query a registry at llama.meta.com/llamastack/distributions # eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [ headers = [
@ -57,7 +57,7 @@ class StackListProviders(Subcommand):
rows = [] rows = []
specs = [spec for p in providers for spec in p.values()] specs = [spec for api, p in providers for spec in p.values()]
for spec in specs: for spec in specs:
if spec.is_sample: if spec.is_sample:
continue continue
@ -65,7 +65,7 @@ class StackListProviders(Subcommand):
[ [
spec.api.value, spec.api.value,
spec.provider_type, spec.provider_type,
",".join(spec.pip_packages), ",".join(spec.pip_packages) if hasattr(spec, "pip_packages") else "",
] ]
) )
print_table( print_table(

View file

@ -73,11 +73,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
existing_providers = config.providers.get(api_str, []) existing_providers = config.providers.get(api_str, [])
if existing_providers: if existing_providers:
logger.info( logger.info(f"Re-configuring existing providers for API `{api_str}`...")
f"Re-configuring existing providers for API `{api_str}`...",
"green",
attrs=["bold"],
)
updated_providers = [] updated_providers = []
for p in existing_providers: for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`") logger.info(f"> Configuring provider `({p.provider_type})`")
@ -91,7 +87,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if not plist: if not plist:
raise ValueError(f"No provider configured for API {api_str}?") raise ValueError(f"No provider configured for API {api_str}?")
logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"]) logger.info(f"Configuring API `{api_str}`...")
updated_providers = [] updated_providers = []
for i, provider_type in enumerate(plist): for i, provider_type in enumerate(plist):
if i >= 1: if i >= 1:

View file

@ -30,7 +30,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
from llama_stack.distribution.request_headers import ( from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR, PROVIDER_DATA_VAR,
request_provider_data_context, request_provider_data_context,
@ -216,7 +216,18 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
"yellow", "yellow",
) )
if self.config_path_or_template_name.endswith(".yaml"): if self.config_path_or_template_name.endswith(".yaml"):
print_pip_install_help(self.config.providers) # Convert Provider objects to their types
provider_types: dict[str, str | list[str]] = {}
for api, providers in self.config.providers.items():
types = [p.provider_type for p in providers]
# Convert single-item lists to strings
provider_types[api] = types[0] if len(types) == 1 else types
build_config = BuildConfig(
distribution_spec=DistributionSpec(
providers=provider_types,
),
)
print_pip_install_help(build_config)
else: else:
prefix = "!" if in_notebook() else "" prefix = "!" if in_notebook() else ""
cprint( cprint(

View file

@ -44,7 +44,8 @@ class RequestProviderDataContext(AbstractContextManager):
class NeedsRequestProviderData: class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any: def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__ spec = self.__provider_spec__
assert spec, f"Provider spec not set on {self.__class__}" if not spec:
raise ValueError(f"Provider spec not set on {self.__class__}")
provider_type = spec.provider_type provider_type = spec.provider_type
validator_class = spec.provider_data_validator validator_class = spec.provider_data_validator

View file

@ -124,7 +124,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
message_placeholder.markdown(full_response + "") message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response) message_placeholder.markdown(full_response)
else: else:
full_response = response full_response = response.completion_message.content
message_placeholder.markdown(full_response.completion_message.content) message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response}) st.session_state.messages.append({"role": "assistant", "content": full_response})

View file

@ -245,7 +245,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{"function_description": self._gen_function_description(custom_tools)}, {"function_description": self._gen_function_description(custom_tools)},
) )
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate: def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> str:
template_str = textwrap.dedent( template_str = textwrap.dedent(
""" """
Here is a list of functions in JSON format that you can invoke. Here is a list of functions in JSON format that you can invoke.
@ -286,10 +286,12 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
""" """
) )
return PromptTemplate( template = PromptTemplate(
template_str.strip("\n"), template_str.strip("\n"),
{"tools": [t.model_dump() for t in custom_tools]}, {"tools": [t.model_dump() for t in custom_tools]},
).render() )
rendered: str = template.render()
return rendered
def data_examples(self) -> list[list[ToolDefinition]]: def data_examples(self) -> list[list[ToolDefinition]]:
return [ return [

View file

@ -948,6 +948,8 @@ def llama_meta_net_info(model: Model) -> LlamaDownloadInfo:
elif model.core_model_id == CoreModelId.llama_guard_2_8b: elif model.core_model_id == CoreModelId.llama_guard_2_8b:
folder = "llama-guard-2" folder = "llama-guard-2"
else: else:
if model.huggingface_repo is None:
raise ValueError(f"Model {model.core_model_id} has no huggingface_repo set")
folder = model.huggingface_repo.split("/")[-1] folder = model.huggingface_repo.split("/")[-1]
if "Llama-2" in folder: if "Llama-2" in folder:
folder = folder.lower() folder = folder.lower()
@ -1024,3 +1026,4 @@ def llama_meta_pth_size(model: Model) -> int:
return 54121549657 return 54121549657
else: else:
return 100426653046 return 100426653046
return 0

View file

@ -139,6 +139,8 @@ class OllamaInferenceAdapter(
if sampling_params is None: if sampling_params is None:
sampling_params = SamplingParams() sampling_params = SamplingParams()
model = await self._get_model(model_id) model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest( request = CompletionRequest(
model=model.provider_resource_id, model=model.provider_resource_id,
content=content, content=content,
@ -202,6 +204,8 @@ class OllamaInferenceAdapter(
if sampling_params is None: if sampling_params is None:
sampling_params = SamplingParams() sampling_params = SamplingParams()
model = await self._get_model(model_id) model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = ChatCompletionRequest( request = ChatCompletionRequest(
model=model.provider_resource_id, model=model.provider_resource_id,
messages=messages, messages=messages,
@ -346,6 +350,8 @@ class OllamaInferenceAdapter(
# - models not currently running are run by the ollama server as needed # - models not currently running are run by the ollama server as needed
response = await self.client.list() response = await self.client.list()
available_models = [m["model"] for m in response["models"]] available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id) provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
if provider_resource_id is None: if provider_resource_id is None:
provider_resource_id = model.provider_resource_id provider_resource_id = model.provider_resource_id

View file

@ -272,6 +272,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_params is None: if sampling_params is None:
sampling_params = SamplingParams() sampling_params = SamplingParams()
model = await self._get_model(model_id) model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest( request = CompletionRequest(
model=model.provider_resource_id, model=model.provider_resource_id,
content=content, content=content,
@ -302,6 +304,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_params is None: if sampling_params is None:
sampling_params = SamplingParams() sampling_params = SamplingParams()
model = await self._get_model(model_id) model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3 # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References: # References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice

View file

@ -382,7 +382,7 @@ def augment_messages_for_tools_llama_3_1(
messages.append(SystemMessage(content=sys_content)) messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools) has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools: if has_custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json: if fmt == ToolPromptFormat.json:

View file

@ -203,58 +203,24 @@ follow_imports = "silent"
# to exclude the entire directory. # to exclude the entire directory.
exclude = [ exclude = [
# As we fix more and more of these, we should remove them from the list # As we fix more and more of these, we should remove them from the list
"^llama_stack/apis/agents/agents\\.py$",
"^llama_stack/apis/batch_inference/batch_inference\\.py$",
"^llama_stack/apis/benchmarks/benchmarks\\.py$",
"^llama_stack/apis/common/content_types\\.py$",
"^llama_stack/apis/common/training_types\\.py$", "^llama_stack/apis/common/training_types\\.py$",
"^llama_stack/apis/datasetio/datasetio\\.py$",
"^llama_stack/apis/datasets/datasets\\.py$",
"^llama_stack/apis/eval/eval\\.py$",
"^llama_stack/apis/files/files\\.py$",
"^llama_stack/apis/inference/inference\\.py$",
"^llama_stack/apis/inspect/inspect\\.py$",
"^llama_stack/apis/models/models\\.py$",
"^llama_stack/apis/post_training/post_training\\.py$",
"^llama_stack/apis/providers/providers\\.py$",
"^llama_stack/apis/resource\\.py$",
"^llama_stack/apis/safety/safety\\.py$",
"^llama_stack/apis/scoring/scoring\\.py$",
"^llama_stack/apis/scoring_functions/scoring_functions\\.py$",
"^llama_stack/apis/shields/shields\\.py$",
"^llama_stack/apis/synthetic_data_generation/synthetic_data_generation\\.py$",
"^llama_stack/apis/telemetry/telemetry\\.py$",
"^llama_stack/apis/tools/rag_tool\\.py$",
"^llama_stack/apis/tools/tools\\.py$",
"^llama_stack/apis/vector_dbs/vector_dbs\\.py$",
"^llama_stack/apis/vector_io/vector_io\\.py$",
"^llama_stack/cli/download\\.py$", "^llama_stack/cli/download\\.py$",
"^llama_stack/cli/llama\\.py$",
"^llama_stack/cli/stack/_build\\.py$", "^llama_stack/cli/stack/_build\\.py$",
"^llama_stack/cli/stack/list_providers\\.py$",
"^llama_stack/distribution/build\\.py$", "^llama_stack/distribution/build\\.py$",
"^llama_stack/distribution/client\\.py$", "^llama_stack/distribution/client\\.py$",
"^llama_stack/distribution/configure\\.py$",
"^llama_stack/distribution/library_client\\.py$",
"^llama_stack/distribution/request_headers\\.py$", "^llama_stack/distribution/request_headers\\.py$",
"^llama_stack/distribution/routers/", "^llama_stack/distribution/routers/",
"^llama_stack/distribution/server/endpoints\\.py$", "^llama_stack/distribution/server/endpoints\\.py$",
"^llama_stack/distribution/server/server\\.py$", "^llama_stack/distribution/server/server\\.py$",
"^llama_stack/distribution/server/websocket_server\\.py$",
"^llama_stack/distribution/stack\\.py$", "^llama_stack/distribution/stack\\.py$",
"^llama_stack/distribution/store/registry\\.py$", "^llama_stack/distribution/store/registry\\.py$",
"^llama_stack/distribution/ui/page/playground/chat\\.py$",
"^llama_stack/distribution/utils/exec\\.py$", "^llama_stack/distribution/utils/exec\\.py$",
"^llama_stack/distribution/utils/prompt_for_config\\.py$", "^llama_stack/distribution/utils/prompt_for_config\\.py$",
"^llama_stack/models/llama/datatypes\\.py$",
"^llama_stack/models/llama/llama3/chat_format\\.py$", "^llama_stack/models/llama/llama3/chat_format\\.py$",
"^llama_stack/models/llama/llama3/interface\\.py$", "^llama_stack/models/llama/llama3/interface\\.py$",
"^llama_stack/models/llama/llama3/prompt_templates/system_prompts\\.py$",
"^llama_stack/models/llama/llama3/tokenizer\\.py$", "^llama_stack/models/llama/llama3/tokenizer\\.py$",
"^llama_stack/models/llama/llama3/tool_utils\\.py$", "^llama_stack/models/llama/llama3/tool_utils\\.py$",
"^llama_stack/models/llama/llama3_3/prompts\\.py$", "^llama_stack/models/llama/llama3_3/prompts\\.py$",
"^llama_stack/models/llama/llama4/",
"^llama_stack/models/llama/sku_list\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/", "^llama_stack/providers/inline/agents/meta_reference/",
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$", "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$", "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
@ -333,7 +299,6 @@ exclude = [
"^llama_stack/providers/utils/telemetry/dataset_mixin\\.py$", "^llama_stack/providers/utils/telemetry/dataset_mixin\\.py$",
"^llama_stack/providers/utils/telemetry/trace_protocol\\.py$", "^llama_stack/providers/utils/telemetry/trace_protocol\\.py$",
"^llama_stack/providers/utils/telemetry/tracing\\.py$", "^llama_stack/providers/utils/telemetry/tracing\\.py$",
"^llama_stack/scripts/",
"^llama_stack/strong_typing/auxiliary\\.py$", "^llama_stack/strong_typing/auxiliary\\.py$",
"^llama_stack/strong_typing/deserializer\\.py$", "^llama_stack/strong_typing/deserializer\\.py$",
"^llama_stack/strong_typing/inspection\\.py$", "^llama_stack/strong_typing/inspection\\.py$",