From 1a529705da65ab79b9fc3eff0a4a8d9aa17b2c88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 6 May 2025 18:52:31 +0200 Subject: [PATCH] chore: more mypy fixes (#2029) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Mainly tried to cover the entire llama_stack/apis directory, we only have one left. Some excludes were just noop. Signed-off-by: Sébastien Han --- docs/_static/llama-stack-spec.html | 251 ++++++++++++++++-- docs/_static/llama-stack-spec.yaml | 219 +++++++++++++-- llama_stack/apis/agents/agents.py | 49 ++-- llama_stack/apis/benchmarks/benchmarks.py | 4 +- llama_stack/apis/common/content_types.py | 2 +- llama_stack/apis/datasets/datasets.py | 4 +- llama_stack/apis/inference/inference.py | 27 +- llama_stack/apis/models/models.py | 4 +- llama_stack/apis/resource.py | 17 +- llama_stack/apis/safety/safety.py | 2 +- .../scoring_functions/scoring_functions.py | 43 +-- llama_stack/apis/shields/shields.py | 4 +- llama_stack/apis/telemetry/telemetry.py | 14 +- llama_stack/apis/tools/tools.py | 4 +- llama_stack/apis/vector_dbs/vector_dbs.py | 4 +- llama_stack/cli/llama.py | 5 +- llama_stack/cli/stack/list_providers.py | 6 +- llama_stack/distribution/configure.py | 8 +- llama_stack/distribution/library_client.py | 15 +- llama_stack/distribution/request_headers.py | 3 +- .../distribution/ui/page/playground/chat.py | 4 +- .../llama3/prompt_templates/system_prompts.py | 8 +- llama_stack/models/llama/sku_list.py | 3 + .../remote/inference/ollama/ollama.py | 6 + .../providers/remote/inference/vllm/vllm.py | 4 + .../utils/inference/prompt_adapter.py | 2 +- pyproject.toml | 35 --- 27 files changed, 581 insertions(+), 166 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 2875f0f41..84d4cf646 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4052,9 +4052,13 @@ "properties": { "type": { "type": "string", + "enum": [ + "json_schema", + "grammar" + ], + "description": "Must be \"grammar\" to identify this format type", "const": "grammar", - "default": "grammar", - "description": "Must be \"grammar\" to identify this format type" + "default": "grammar" }, "bnf": { "type": "object", @@ -4178,9 +4182,13 @@ "properties": { "type": { "type": "string", + "enum": [ + "json_schema", + "grammar" + ], + "description": "Must be \"json_schema\" to identify this format type", "const": "json_schema", - "default": "json_schema", - "description": "Must be \"json_schema\" to identify this format type" + "default": "json_schema" }, "json_schema": { "type": "object", @@ -5638,6 +5646,14 @@ }, "step_type": { "type": "string", + "enum": [ + "inference", + "tool_execution", + "shield_call", + "memory_retrieval" + ], + "title": "StepType", + "description": "Type of the step in an agent turn.", "const": "inference", "default": "inference" }, @@ -5679,6 +5695,14 @@ }, "step_type": { "type": "string", + "enum": [ + "inference", + "tool_execution", + "shield_call", + "memory_retrieval" + ], + "title": "StepType", + "description": "Type of the step in an agent turn.", "const": "memory_retrieval", "default": "memory_retrieval" }, @@ -5767,6 +5791,14 @@ }, "step_type": { "type": "string", + "enum": [ + "inference", + "tool_execution", + "shield_call", + "memory_retrieval" + ], + "title": "StepType", + "description": "Type of the step in an agent turn.", "const": "shield_call", "default": "shield_call" }, @@ -5807,6 +5839,14 @@ }, "step_type": { "type": "string", + "enum": [ + "inference", + "tool_execution", + "shield_call", + "memory_retrieval" + ], + "title": "StepType", + "description": "Type of the step in an agent turn.", "const": "tool_execution", "default": "tool_execution" }, @@ -6069,6 +6109,15 @@ "properties": { "event_type": { "type": "string", + "enum": [ + "step_start", + "step_complete", + "step_progress", + "turn_start", + "turn_complete", + "turn_awaiting_input" + ], + "title": "AgentTurnResponseEventType", "const": "step_complete", "default": "step_complete" }, @@ -6126,6 +6175,15 @@ "properties": { "event_type": { "type": "string", + "enum": [ + "step_start", + "step_complete", + "step_progress", + "turn_start", + "turn_complete", + "turn_awaiting_input" + ], + "title": "AgentTurnResponseEventType", "const": "step_progress", "default": "step_progress" }, @@ -6161,6 +6219,15 @@ "properties": { "event_type": { "type": "string", + "enum": [ + "step_start", + "step_complete", + "step_progress", + "turn_start", + "turn_complete", + "turn_awaiting_input" + ], + "title": "AgentTurnResponseEventType", "const": "step_start", "default": "step_start" }, @@ -6231,6 +6298,15 @@ "properties": { "event_type": { "type": "string", + "enum": [ + "step_start", + "step_complete", + "step_progress", + "turn_start", + "turn_complete", + "turn_awaiting_input" + ], + "title": "AgentTurnResponseEventType", "const": "turn_awaiting_input", "default": "turn_awaiting_input" }, @@ -6250,6 +6326,15 @@ "properties": { "event_type": { "type": "string", + "enum": [ + "step_start", + "step_complete", + "step_progress", + "turn_start", + "turn_complete", + "turn_awaiting_input" + ], + "title": "AgentTurnResponseEventType", "const": "turn_complete", "default": "turn_complete" }, @@ -6269,6 +6354,15 @@ "properties": { "event_type": { "type": "string", + "enum": [ + "step_start", + "step_complete", + "step_progress", + "turn_start", + "turn_complete", + "turn_awaiting_input" + ], + "title": "AgentTurnResponseEventType", "const": "turn_start", "default": "turn_start" }, @@ -6876,7 +6970,7 @@ "type": "object", "properties": { "type": { - "type": "string", + "$ref": "#/components/schemas/ScoringFnParamsType", "const": "basic", "default": "basic" }, @@ -6889,7 +6983,8 @@ }, "additionalProperties": false, "required": [ - "type" + "type", + "aggregation_functions" ], "title": "BasicScoringFnParams" }, @@ -6941,7 +7036,7 @@ "type": "object", "properties": { "type": { - "type": "string", + "$ref": "#/components/schemas/ScoringFnParamsType", "const": "llm_as_judge", "default": "llm_as_judge" }, @@ -6967,7 +7062,9 @@ "additionalProperties": false, "required": [ "type", - "judge_model" + "judge_model", + "judge_score_regexes", + "aggregation_functions" ], "title": "LLMAsJudgeScoringFnParams" }, @@ -7005,7 +7102,7 @@ "type": "object", "properties": { "type": { - "type": "string", + "$ref": "#/components/schemas/ScoringFnParamsType", "const": "regex_parser", "default": "regex_parser" }, @@ -7024,7 +7121,9 @@ }, "additionalProperties": false, "required": [ - "type" + "type", + "parsing_regexes", + "aggregation_functions" ], "title": "RegexParserScoringFnParams" }, @@ -7049,6 +7148,15 @@ } } }, + "ScoringFnParamsType": { + "type": "string", + "enum": [ + "llm_as_judge", + "regex_parser", + "basic" + ], + "title": "ScoringFnParamsType" + }, "EvaluateRowsRequest": { "type": "object", "properties": { @@ -7317,6 +7425,17 @@ }, "type": { "type": "string", + "enum": [ + "model", + "shield", + "vector_db", + "dataset", + "scoring_function", + "benchmark", + "tool", + "tool_group" + ], + "title": "ResourceType", "const": "benchmark", "default": "benchmark" }, @@ -7358,7 +7477,6 @@ "additionalProperties": false, "required": [ "identifier", - "provider_resource_id", "provider_id", "type", "dataset_id", @@ -7398,6 +7516,17 @@ }, "type": { "type": "string", + "enum": [ + "model", + "shield", + "vector_db", + "dataset", + "scoring_function", + "benchmark", + "tool", + "tool_group" + ], + "title": "ResourceType", "const": "dataset", "default": "dataset" }, @@ -7443,7 +7572,6 @@ "additionalProperties": false, "required": [ "identifier", - "provider_resource_id", "provider_id", "type", "purpose", @@ -7573,6 +7701,17 @@ }, "type": { "type": "string", + "enum": [ + "model", + "shield", + "vector_db", + "dataset", + "scoring_function", + "benchmark", + "tool", + "tool_group" + ], + "title": "ResourceType", "const": "model", "default": "model" }, @@ -7609,7 +7748,6 @@ "additionalProperties": false, "required": [ "identifier", - "provider_resource_id", "provider_id", "type", "metadata", @@ -7808,6 +7946,17 @@ }, "type": { "type": "string", + "enum": [ + "model", + "shield", + "vector_db", + "dataset", + "scoring_function", + "benchmark", + "tool", + "tool_group" + ], + "title": "ResourceType", "const": "scoring_function", "default": "scoring_function" }, @@ -7849,7 +7998,6 @@ "additionalProperties": false, "required": [ "identifier", - "provider_resource_id", "provider_id", "type", "metadata", @@ -7901,6 +8049,17 @@ }, "type": { "type": "string", + "enum": [ + "model", + "shield", + "vector_db", + "dataset", + "scoring_function", + "benchmark", + "tool", + "tool_group" + ], + "title": "ResourceType", "const": "shield", "default": "shield" }, @@ -7933,7 +8092,6 @@ "additionalProperties": false, "required": [ "identifier", - "provider_resource_id", "provider_id", "type" ], @@ -8113,6 +8271,17 @@ }, "type": { "type": "string", + "enum": [ + "model", + "shield", + "vector_db", + "dataset", + "scoring_function", + "benchmark", + "tool", + "tool_group" + ], + "title": "ResourceType", "const": "tool", "default": "tool" }, @@ -8160,7 +8329,6 @@ "additionalProperties": false, "required": [ "identifier", - "provider_resource_id", "provider_id", "type", "toolgroup_id", @@ -8193,6 +8361,17 @@ }, "type": { "type": "string", + "enum": [ + "model", + "shield", + "vector_db", + "dataset", + "scoring_function", + "benchmark", + "tool", + "tool_group" + ], + "title": "ResourceType", "const": "tool_group", "default": "tool_group" }, @@ -8228,7 +8407,6 @@ "additionalProperties": false, "required": [ "identifier", - "provider_resource_id", "provider_id", "type" ], @@ -8395,6 +8573,17 @@ }, "type": { "type": "string", + "enum": [ + "model", + "shield", + "vector_db", + "dataset", + "scoring_function", + "benchmark", + "tool", + "tool_group" + ], + "title": "ResourceType", "const": "vector_db", "default": "vector_db" }, @@ -8408,7 +8597,6 @@ "additionalProperties": false, "required": [ "identifier", - "provider_resource_id", "provider_id", "type", "embedding_model", @@ -9110,6 +9298,15 @@ } } }, + "EventType": { + "type": "string", + "enum": [ + "unstructured_log", + "structured_log", + "metric" + ], + "title": "EventType" + }, "LogSeverity": { "type": "string", "enum": [ @@ -9158,7 +9355,7 @@ } }, "type": { - "type": "string", + "$ref": "#/components/schemas/EventType", "const": "metric", "default": "metric" }, @@ -9195,7 +9392,7 @@ "type": "object", "properties": { "type": { - "type": "string", + "$ref": "#/components/schemas/StructuredLogType", "const": "span_end", "default": "span_end" }, @@ -9214,7 +9411,7 @@ "type": "object", "properties": { "type": { - "type": "string", + "$ref": "#/components/schemas/StructuredLogType", "const": "span_start", "default": "span_start" }, @@ -9268,7 +9465,7 @@ } }, "type": { - "type": "string", + "$ref": "#/components/schemas/EventType", "const": "structured_log", "default": "structured_log" }, @@ -9303,6 +9500,14 @@ } } }, + "StructuredLogType": { + "type": "string", + "enum": [ + "span_start", + "span_end" + ], + "title": "StructuredLogType" + }, "UnstructuredLogEvent": { "type": "object", "properties": { @@ -9339,7 +9544,7 @@ } }, "type": { - "type": "string", + "$ref": "#/components/schemas/EventType", "const": "unstructured_log", "default": "unstructured_log" }, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index cb73919d9..259cb3007 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2812,10 +2812,13 @@ components: properties: type: type: string - const: grammar - default: grammar + enum: + - json_schema + - grammar description: >- Must be "grammar" to identify this format type + const: grammar + default: grammar bnf: type: object additionalProperties: @@ -2897,10 +2900,13 @@ components: properties: type: type: string - const: json_schema - default: json_schema + enum: + - json_schema + - grammar description: >- Must be "json_schema" to identify this format type + const: json_schema + default: json_schema json_schema: type: object additionalProperties: @@ -3959,6 +3965,13 @@ components: description: The time the step completed. step_type: type: string + enum: + - inference + - tool_execution + - shield_call + - memory_retrieval + title: StepType + description: Type of the step in an agent turn. const: inference default: inference model_response: @@ -3991,6 +4004,13 @@ components: description: The time the step completed. step_type: type: string + enum: + - inference + - tool_execution + - shield_call + - memory_retrieval + title: StepType + description: Type of the step in an agent turn. const: memory_retrieval default: memory_retrieval vector_db_ids: @@ -4052,6 +4072,13 @@ components: description: The time the step completed. step_type: type: string + enum: + - inference + - tool_execution + - shield_call + - memory_retrieval + title: StepType + description: Type of the step in an agent turn. const: shield_call default: shield_call violation: @@ -4083,6 +4110,13 @@ components: description: The time the step completed. step_type: type: string + enum: + - inference + - tool_execution + - shield_call + - memory_retrieval + title: StepType + description: Type of the step in an agent turn. const: tool_execution default: tool_execution tool_calls: @@ -4245,6 +4279,14 @@ components: properties: event_type: type: string + enum: + - step_start + - step_complete + - step_progress + - turn_start + - turn_complete + - turn_awaiting_input + title: AgentTurnResponseEventType const: step_complete default: step_complete step_type: @@ -4283,6 +4325,14 @@ components: properties: event_type: type: string + enum: + - step_start + - step_complete + - step_progress + - turn_start + - turn_complete + - turn_awaiting_input + title: AgentTurnResponseEventType const: step_progress default: step_progress step_type: @@ -4310,6 +4360,14 @@ components: properties: event_type: type: string + enum: + - step_start + - step_complete + - step_progress + - turn_start + - turn_complete + - turn_awaiting_input + title: AgentTurnResponseEventType const: step_start default: step_start step_type: @@ -4354,6 +4412,14 @@ components: properties: event_type: type: string + enum: + - step_start + - step_complete + - step_progress + - turn_start + - turn_complete + - turn_awaiting_input + title: AgentTurnResponseEventType const: turn_awaiting_input default: turn_awaiting_input turn: @@ -4369,6 +4435,14 @@ components: properties: event_type: type: string + enum: + - step_start + - step_complete + - step_progress + - turn_start + - turn_complete + - turn_awaiting_input + title: AgentTurnResponseEventType const: turn_complete default: turn_complete turn: @@ -4383,6 +4457,14 @@ components: properties: event_type: type: string + enum: + - step_start + - step_complete + - step_progress + - turn_start + - turn_complete + - turn_awaiting_input + title: AgentTurnResponseEventType const: turn_start default: turn_start turn_id: @@ -4825,7 +4907,7 @@ components: type: object properties: type: - type: string + $ref: '#/components/schemas/ScoringFnParamsType' const: basic default: basic aggregation_functions: @@ -4835,6 +4917,7 @@ components: additionalProperties: false required: - type + - aggregation_functions title: BasicScoringFnParams BenchmarkConfig: type: object @@ -4874,7 +4957,7 @@ components: type: object properties: type: - type: string + $ref: '#/components/schemas/ScoringFnParamsType' const: llm_as_judge default: llm_as_judge judge_model: @@ -4893,6 +4976,8 @@ components: required: - type - judge_model + - judge_score_regexes + - aggregation_functions title: LLMAsJudgeScoringFnParams ModelCandidate: type: object @@ -4923,7 +5008,7 @@ components: type: object properties: type: - type: string + $ref: '#/components/schemas/ScoringFnParamsType' const: regex_parser default: regex_parser parsing_regexes: @@ -4937,6 +5022,8 @@ components: additionalProperties: false required: - type + - parsing_regexes + - aggregation_functions title: RegexParserScoringFnParams ScoringFnParams: oneOf: @@ -4949,6 +5036,13 @@ components: llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' regex_parser: '#/components/schemas/RegexParserScoringFnParams' basic: '#/components/schemas/BasicScoringFnParams' + ScoringFnParamsType: + type: string + enum: + - llm_as_judge + - regex_parser + - basic + title: ScoringFnParamsType EvaluateRowsRequest: type: object properties: @@ -5111,6 +5205,16 @@ components: type: string type: type: string + enum: + - model + - shield + - vector_db + - dataset + - scoring_function + - benchmark + - tool + - tool_group + title: ResourceType const: benchmark default: benchmark dataset_id: @@ -5132,7 +5236,6 @@ components: additionalProperties: false required: - identifier - - provider_resource_id - provider_id - type - dataset_id @@ -5159,6 +5262,16 @@ components: type: string type: type: string + enum: + - model + - shield + - vector_db + - dataset + - scoring_function + - benchmark + - tool + - tool_group + title: ResourceType const: dataset default: dataset purpose: @@ -5185,7 +5298,6 @@ components: additionalProperties: false required: - identifier - - provider_resource_id - provider_id - type - purpose @@ -5284,6 +5396,16 @@ components: type: string type: type: string + enum: + - model + - shield + - vector_db + - dataset + - scoring_function + - benchmark + - tool + - tool_group + title: ResourceType const: model default: model metadata: @@ -5302,7 +5424,6 @@ components: additionalProperties: false required: - identifier - - provider_resource_id - provider_id - type - metadata @@ -5438,6 +5559,16 @@ components: type: string type: type: string + enum: + - model + - shield + - vector_db + - dataset + - scoring_function + - benchmark + - tool + - tool_group + title: ResourceType const: scoring_function default: scoring_function description: @@ -5459,7 +5590,6 @@ components: additionalProperties: false required: - identifier - - provider_resource_id - provider_id - type - metadata @@ -5498,6 +5628,16 @@ components: type: string type: type: string + enum: + - model + - shield + - vector_db + - dataset + - scoring_function + - benchmark + - tool + - tool_group + title: ResourceType const: shield default: shield params: @@ -5513,7 +5653,6 @@ components: additionalProperties: false required: - identifier - - provider_resource_id - provider_id - type title: Shield @@ -5628,6 +5767,16 @@ components: type: string type: type: string + enum: + - model + - shield + - vector_db + - dataset + - scoring_function + - benchmark + - tool + - tool_group + title: ResourceType const: tool default: tool toolgroup_id: @@ -5653,7 +5802,6 @@ components: additionalProperties: false required: - identifier - - provider_resource_id - provider_id - type - toolgroup_id @@ -5679,6 +5827,16 @@ components: type: string type: type: string + enum: + - model + - shield + - vector_db + - dataset + - scoring_function + - benchmark + - tool + - tool_group + title: ResourceType const: tool_group default: tool_group mcp_endpoint: @@ -5696,7 +5854,6 @@ components: additionalProperties: false required: - identifier - - provider_resource_id - provider_id - type title: ToolGroup @@ -5810,6 +5967,16 @@ components: type: string type: type: string + enum: + - model + - shield + - vector_db + - dataset + - scoring_function + - benchmark + - tool + - tool_group + title: ResourceType const: vector_db default: vector_db embedding_model: @@ -5819,7 +5986,6 @@ components: additionalProperties: false required: - identifier - - provider_resource_id - provider_id - type - embedding_model @@ -6259,6 +6425,13 @@ components: unstructured_log: '#/components/schemas/UnstructuredLogEvent' metric: '#/components/schemas/MetricEvent' structured_log: '#/components/schemas/StructuredLogEvent' + EventType: + type: string + enum: + - unstructured_log + - structured_log + - metric + title: EventType LogSeverity: type: string enum: @@ -6289,7 +6462,7 @@ components: - type: boolean - type: 'null' type: - type: string + $ref: '#/components/schemas/EventType' const: metric default: metric metric: @@ -6314,7 +6487,7 @@ components: type: object properties: type: - type: string + $ref: '#/components/schemas/StructuredLogType' const: span_end default: span_end status: @@ -6328,7 +6501,7 @@ components: type: object properties: type: - type: string + $ref: '#/components/schemas/StructuredLogType' const: span_start default: span_start name: @@ -6360,7 +6533,7 @@ components: - type: boolean - type: 'null' type: - type: string + $ref: '#/components/schemas/EventType' const: structured_log default: structured_log payload: @@ -6382,6 +6555,12 @@ components: mapping: span_start: '#/components/schemas/SpanStartPayload' span_end: '#/components/schemas/SpanEndPayload' + StructuredLogType: + type: string + enum: + - span_start + - span_end + title: StructuredLogType UnstructuredLogEvent: type: object properties: @@ -6402,7 +6581,7 @@ components: - type: boolean - type: 'null' type: - type: string + $ref: '#/components/schemas/EventType' const: unstructured_log default: unstructured_log message: diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 30b37e98f..84e3a9057 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import sys from collections.abc import AsyncIterator from datetime import datetime from enum import Enum @@ -35,6 +36,14 @@ from .openai_responses import ( OpenAIResponseObjectStream, ) +# TODO: use enum.StrEnum when we drop support for python 3.10 +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + + class StrEnum(str, Enum): + """Backport of StrEnum for Python 3.10 and below.""" + class Attachment(BaseModel): """An attachment to an agent turn. @@ -73,7 +82,7 @@ class StepCommon(BaseModel): completed_at: datetime | None = None -class StepType(Enum): +class StepType(StrEnum): """Type of the step in an agent turn. :cvar inference: The step is an inference step that calls an LLM. @@ -97,7 +106,7 @@ class InferenceStep(StepCommon): model_config = ConfigDict(protected_namespaces=()) - step_type: Literal[StepType.inference.value] = StepType.inference.value + step_type: Literal[StepType.inference] = StepType.inference model_response: CompletionMessage @@ -109,7 +118,7 @@ class ToolExecutionStep(StepCommon): :param tool_responses: The tool responses from the tool calls. """ - step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value + step_type: Literal[StepType.tool_execution] = StepType.tool_execution tool_calls: list[ToolCall] tool_responses: list[ToolResponse] @@ -121,7 +130,7 @@ class ShieldCallStep(StepCommon): :param violation: The violation from the shield call. """ - step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value + step_type: Literal[StepType.shield_call] = StepType.shield_call violation: SafetyViolation | None @@ -133,7 +142,7 @@ class MemoryRetrievalStep(StepCommon): :param inserted_context: The context retrieved from the vector databases. """ - step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value + step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval # TODO: should this be List[str]? vector_db_ids: str inserted_context: InterleavedContent @@ -154,7 +163,7 @@ class Turn(BaseModel): input_messages: list[UserMessage | ToolResponseMessage] steps: list[Step] output_message: CompletionMessage - output_attachments: list[Attachment] | None = Field(default_factory=list) + output_attachments: list[Attachment] | None = Field(default_factory=lambda: []) started_at: datetime completed_at: datetime | None = None @@ -182,10 +191,10 @@ register_schema(AgentToolGroup, name="AgentTool") class AgentConfigCommon(BaseModel): sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) - input_shields: list[str] | None = Field(default_factory=list) - output_shields: list[str] | None = Field(default_factory=list) - toolgroups: list[AgentToolGroup] | None = Field(default_factory=list) - client_tools: list[ToolDef] | None = Field(default_factory=list) + input_shields: list[str] | None = Field(default_factory=lambda: []) + output_shields: list[str] | None = Field(default_factory=lambda: []) + toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: []) + client_tools: list[ToolDef] | None = Field(default_factory=lambda: []) tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead") tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead") tool_config: ToolConfig | None = Field(default=None) @@ -246,7 +255,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon): instructions: str | None = None -class AgentTurnResponseEventType(Enum): +class AgentTurnResponseEventType(StrEnum): step_start = "step_start" step_complete = "step_complete" step_progress = "step_progress" @@ -258,15 +267,15 @@ class AgentTurnResponseEventType(Enum): @json_schema_type class AgentTurnResponseStepStartPayload(BaseModel): - event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value + event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start step_type: StepType step_id: str - metadata: dict[str, Any] | None = Field(default_factory=dict) + metadata: dict[str, Any] | None = Field(default_factory=lambda: {}) @json_schema_type class AgentTurnResponseStepCompletePayload(BaseModel): - event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value + event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete step_type: StepType step_id: str step_details: Step @@ -276,7 +285,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel): class AgentTurnResponseStepProgressPayload(BaseModel): model_config = ConfigDict(protected_namespaces=()) - event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value + event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress step_type: StepType step_id: str @@ -285,21 +294,19 @@ class AgentTurnResponseStepProgressPayload(BaseModel): @json_schema_type class AgentTurnResponseTurnStartPayload(BaseModel): - event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value + event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start turn_id: str @json_schema_type class AgentTurnResponseTurnCompletePayload(BaseModel): - event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value + event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete turn: Turn @json_schema_type class AgentTurnResponseTurnAwaitingInputPayload(BaseModel): - event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = ( - AgentTurnResponseEventType.turn_awaiting_input.value - ) + event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input turn: Turn @@ -341,7 +348,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): messages: list[UserMessage | ToolResponseMessage] documents: list[Document] | None = None - toolgroups: list[AgentToolGroup] | None = None + toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: []) stream: bool | None = False tool_config: ToolConfig | None = None diff --git a/llama_stack/apis/benchmarks/benchmarks.py b/llama_stack/apis/benchmarks/benchmarks.py index 1bba42d20..e3b0502bc 100644 --- a/llama_stack/apis/benchmarks/benchmarks.py +++ b/llama_stack/apis/benchmarks/benchmarks.py @@ -22,14 +22,14 @@ class CommonBenchmarkFields(BaseModel): @json_schema_type class Benchmark(CommonBenchmarkFields, Resource): - type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value + type: Literal[ResourceType.benchmark] = ResourceType.benchmark @property def benchmark_id(self) -> str: return self.identifier @property - def provider_benchmark_id(self) -> str: + def provider_benchmark_id(self) -> str | None: return self.provider_resource_id diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index b9ef033dd..8bcb781f7 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -28,7 +28,7 @@ class _URLOrData(BaseModel): url: URL | None = None # data is a base64 encoded string, hint with contentEncoding=base64 - data: str | None = Field(contentEncoding="base64", default=None) + data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"}) @model_validator(mode="before") @classmethod diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 796217557..a0ee987ad 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -106,14 +106,14 @@ class CommonDatasetFields(BaseModel): @json_schema_type class Dataset(CommonDatasetFields, Resource): - type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value + type: Literal[ResourceType.dataset] = ResourceType.dataset @property def dataset_id(self) -> str: return self.identifier @property - def provider_dataset_id(self) -> str: + def provider_dataset_id(self) -> str | None: return self.provider_resource_id diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index dbcd1d019..00050779b 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import sys from collections.abc import AsyncIterator from enum import Enum from typing import ( @@ -35,6 +36,16 @@ register_schema(ToolCall) register_schema(ToolParamDefinition) register_schema(ToolDefinition) +# TODO: use enum.StrEnum when we drop support for python 3.10 +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + + class StrEnum(str, Enum): + """Backport of StrEnum for Python 3.10 and below.""" + + pass + @json_schema_type class GreedySamplingStrategy(BaseModel): @@ -187,7 +198,7 @@ class CompletionMessage(BaseModel): role: Literal["assistant"] = "assistant" content: InterleavedContent stop_reason: StopReason - tool_calls: list[ToolCall] | None = Field(default_factory=list) + tool_calls: list[ToolCall] | None = Field(default_factory=lambda: []) Message = Annotated[ @@ -267,7 +278,7 @@ class ChatCompletionResponseEvent(BaseModel): stop_reason: StopReason | None = None -class ResponseFormatType(Enum): +class ResponseFormatType(StrEnum): """Types of formats for structured (guided) decoding. :cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model. @@ -286,7 +297,7 @@ class JsonSchemaResponseFormat(BaseModel): :param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model. """ - type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value + type: Literal[ResponseFormatType.json_schema] = ResponseFormatType.json_schema json_schema: dict[str, Any] @@ -298,7 +309,7 @@ class GrammarResponseFormat(BaseModel): :param bnf: The BNF grammar specification the response should conform to """ - type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value + type: Literal[ResponseFormatType.grammar] = ResponseFormatType.grammar bnf: dict[str, Any] @@ -394,7 +405,7 @@ class ChatCompletionRequest(BaseModel): messages: list[Message] sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) - tools: list[ToolDefinition] | None = Field(default_factory=list) + tools: list[ToolDefinition] | None = Field(default_factory=lambda: []) tool_config: ToolConfig | None = Field(default_factory=ToolConfig) response_format: ResponseFormat | None = None @@ -567,14 +578,14 @@ class OpenAIResponseFormatText(BaseModel): @json_schema_type class OpenAIJSONSchema(TypedDict, total=False): name: str - description: str | None = None - strict: bool | None = None + description: str | None + strict: bool | None # Pydantic BaseModel cannot be used with a schema param, since it already # has one. And, we don't want to alias here because then have to handle # that alias when converting to OpenAI params. So, to support schema, # we use a TypedDict. - schema: dict[str, Any] | None = None + schema: dict[str, Any] | None @json_schema_type diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 5d7b5aac6..37ae95fa5 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -29,14 +29,14 @@ class ModelType(str, Enum): @json_schema_type class Model(CommonModelFields, Resource): - type: Literal[ResourceType.model.value] = ResourceType.model.value + type: Literal[ResourceType.model] = ResourceType.model @property def model_id(self) -> str: return self.identifier @property - def provider_model_id(self) -> str: + def provider_model_id(self) -> str | None: return self.provider_resource_id model_config = ConfigDict(protected_namespaces=()) diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index 70ec63c55..175baa7b9 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -4,12 +4,23 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import sys from enum import Enum from pydantic import BaseModel, Field +# TODO: use enum.StrEnum when we drop support for python 3.10 +if sys.version_info >= (3, 11): + from enum import StrEnum +else: -class ResourceType(Enum): + class StrEnum(str, Enum): + """Backport of StrEnum for Python 3.10 and below.""" + + pass + + +class ResourceType(StrEnum): model = "model" shield = "shield" vector_db = "vector_db" @@ -25,9 +36,9 @@ class Resource(BaseModel): identifier: str = Field(description="Unique identifier for this resource in llama stack") - provider_resource_id: str = Field( - description="Unique identifier for this resource in the provider", + provider_resource_id: str | None = Field( default=None, + description="Unique identifier for this resource in the provider", ) provider_id: str = Field(description="ID of the provider that owns this resource") diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index e139f5ffc..b6b58262f 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -53,5 +53,5 @@ class Safety(Protocol): self, shield_id: str, messages: list[Message], - params: dict[str, Any] = None, + params: dict[str, Any], ) -> RunShieldResponse: ... diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 6c7819965..9ba9eb654 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +# TODO: use enum.StrEnum when we drop support for python 3.10 +import sys from enum import Enum from typing import ( Annotated, @@ -19,18 +21,27 @@ from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.resource import Resource, ResourceType from llama_stack.schema_utils import json_schema_type, register_schema, webmethod +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + + class StrEnum(str, Enum): + """Backport of StrEnum for Python 3.10 and below.""" + + pass + # Perhaps more structure can be imposed on these functions. Maybe they could be associated # with standard metrics so they can be rolled up? @json_schema_type -class ScoringFnParamsType(Enum): +class ScoringFnParamsType(StrEnum): llm_as_judge = "llm_as_judge" regex_parser = "regex_parser" basic = "basic" @json_schema_type -class AggregationFunctionType(Enum): +class AggregationFunctionType(StrEnum): average = "average" weighted_average = "weighted_average" median = "median" @@ -40,36 +51,36 @@ class AggregationFunctionType(Enum): @json_schema_type class LLMAsJudgeScoringFnParams(BaseModel): - type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value + type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge judge_model: str prompt_template: str | None = None - judge_score_regexes: list[str] | None = Field( + judge_score_regexes: list[str] = Field( description="Regexes to extract the answer from generated response", - default_factory=list, + default_factory=lambda: [], ) - aggregation_functions: list[AggregationFunctionType] | None = Field( + aggregation_functions: list[AggregationFunctionType] = Field( description="Aggregation functions to apply to the scores of each row", - default_factory=list, + default_factory=lambda: [], ) @json_schema_type class RegexParserScoringFnParams(BaseModel): - type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value - parsing_regexes: list[str] | None = Field( + type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser + parsing_regexes: list[str] = Field( description="Regex to extract the answer from generated response", - default_factory=list, + default_factory=lambda: [], ) - aggregation_functions: list[AggregationFunctionType] | None = Field( + aggregation_functions: list[AggregationFunctionType] = Field( description="Aggregation functions to apply to the scores of each row", - default_factory=list, + default_factory=lambda: [], ) @json_schema_type class BasicScoringFnParams(BaseModel): - type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value - aggregation_functions: list[AggregationFunctionType] | None = Field( + type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic + aggregation_functions: list[AggregationFunctionType] = Field( description="Aggregation functions to apply to the scores of each row", default_factory=list, ) @@ -99,14 +110,14 @@ class CommonScoringFnFields(BaseModel): @json_schema_type class ScoringFn(CommonScoringFnFields, Resource): - type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value + type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function @property def scoring_fn_id(self) -> str: return self.identifier @property - def provider_scoring_fn_id(self) -> str: + def provider_scoring_fn_id(self) -> str | None: return self.provider_resource_id diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 4172fcbd1..66bb9a0b8 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -21,14 +21,14 @@ class CommonShieldFields(BaseModel): class Shield(CommonShieldFields, Resource): """A safety shield resource that can be used to check content""" - type: Literal[ResourceType.shield.value] = ResourceType.shield.value + type: Literal[ResourceType.shield] = ResourceType.shield @property def shield_id(self) -> str: return self.identifier @property - def provider_shield_id(self) -> str: + def provider_shield_id(self) -> str | None: return self.provider_resource_id diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index af0469d2b..34e296fef 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -37,7 +37,7 @@ class Span(BaseModel): name: str start_time: datetime end_time: datetime | None = None - attributes: dict[str, Any] | None = Field(default_factory=dict) + attributes: dict[str, Any] | None = Field(default_factory=lambda: {}) def set_attribute(self, key: str, value: Any): if self.attributes is None: @@ -74,19 +74,19 @@ class EventCommon(BaseModel): trace_id: str span_id: str timestamp: datetime - attributes: dict[str, Primitive] | None = Field(default_factory=dict) + attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {}) @json_schema_type class UnstructuredLogEvent(EventCommon): - type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value + type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG message: str severity: LogSeverity @json_schema_type class MetricEvent(EventCommon): - type: Literal[EventType.METRIC.value] = EventType.METRIC.value + type: Literal[EventType.METRIC] = EventType.METRIC metric: str # this would be an enum value: int | float unit: str @@ -131,14 +131,14 @@ class StructuredLogType(Enum): @json_schema_type class SpanStartPayload(BaseModel): - type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value + type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START name: str parent_span_id: str | None = None @json_schema_type class SpanEndPayload(BaseModel): - type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value + type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END status: SpanStatus @@ -151,7 +151,7 @@ register_schema(StructuredLogPayload, name="StructuredLogPayload") @json_schema_type class StructuredLogEvent(EventCommon): - type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value + type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG payload: StructuredLogPayload diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index eda627932..2860ddbd8 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -36,7 +36,7 @@ class ToolHost(Enum): @json_schema_type class Tool(Resource): - type: Literal[ResourceType.tool.value] = ResourceType.tool.value + type: Literal[ResourceType.tool] = ResourceType.tool toolgroup_id: str tool_host: ToolHost description: str @@ -62,7 +62,7 @@ class ToolGroupInput(BaseModel): @json_schema_type class ToolGroup(Resource): - type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value + type: Literal[ResourceType.tool_group] = ResourceType.tool_group mcp_endpoint: URL | None = None args: dict[str, Any] | None = None diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index 6224566cd..a01892888 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type class VectorDB(Resource): - type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value + type: Literal[ResourceType.vector_db] = ResourceType.vector_db embedding_model: str embedding_dimension: int @@ -25,7 +25,7 @@ class VectorDB(Resource): return self.identifier @property - def provider_vector_db_id(self) -> str: + def provider_vector_db_id(self) -> str | None: return self.provider_resource_id diff --git a/llama_stack/cli/llama.py b/llama_stack/cli/llama.py index 8ff580029..433b311e7 100644 --- a/llama_stack/cli/llama.py +++ b/llama_stack/cli/llama.py @@ -38,7 +38,10 @@ class LlamaCLIParser: print_subcommand_description(self.parser, subparsers) def parse_args(self) -> argparse.Namespace: - return self.parser.parse_args() + args = self.parser.parse_args() + if not isinstance(args, argparse.Namespace): + raise TypeError(f"Expected argparse.Namespace, got {type(args)}") + return args def run(self, args: argparse.Namespace) -> None: args.func(args) diff --git a/llama_stack/cli/stack/list_providers.py b/llama_stack/cli/stack/list_providers.py index bfe11aa2c..deebd937b 100644 --- a/llama_stack/cli/stack/list_providers.py +++ b/llama_stack/cli/stack/list_providers.py @@ -46,7 +46,7 @@ class StackListProviders(Subcommand): else: providers = [(k.value, prov) for k, prov in all_providers.items()] - providers = [p for api, p in providers if api in self.providable_apis] + providers = [(api, p) for api, p in providers if api in self.providable_apis] # eventually, this should query a registry at llama.meta.com/llamastack/distributions headers = [ @@ -57,7 +57,7 @@ class StackListProviders(Subcommand): rows = [] - specs = [spec for p in providers for spec in p.values()] + specs = [spec for api, p in providers for spec in p.values()] for spec in specs: if spec.is_sample: continue @@ -65,7 +65,7 @@ class StackListProviders(Subcommand): [ spec.api.value, spec.provider_type, - ",".join(spec.pip_packages), + ",".join(spec.pip_packages) if hasattr(spec, "pip_packages") else "", ] ) print_table( diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 76167258a..78a6a184e 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -73,11 +73,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec existing_providers = config.providers.get(api_str, []) if existing_providers: - logger.info( - f"Re-configuring existing providers for API `{api_str}`...", - "green", - attrs=["bold"], - ) + logger.info(f"Re-configuring existing providers for API `{api_str}`...") updated_providers = [] for p in existing_providers: logger.info(f"> Configuring provider `({p.provider_type})`") @@ -91,7 +87,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec if not plist: raise ValueError(f"No provider configured for API {api_str}?") - logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"]) + logger.info(f"Configuring API `{api_str}`...") updated_providers = [] for i, provider_type in enumerate(plist): if i >= 1: diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index b2d16d74c..8e5445874 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -30,7 +30,7 @@ from termcolor import cprint from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config -from llama_stack.distribution.datatypes import Api +from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec from llama_stack.distribution.request_headers import ( PROVIDER_DATA_VAR, request_provider_data_context, @@ -216,7 +216,18 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): "yellow", ) if self.config_path_or_template_name.endswith(".yaml"): - print_pip_install_help(self.config.providers) + # Convert Provider objects to their types + provider_types: dict[str, str | list[str]] = {} + for api, providers in self.config.providers.items(): + types = [p.provider_type for p in providers] + # Convert single-item lists to strings + provider_types[api] = types[0] if len(types) == 1 else types + build_config = BuildConfig( + distribution_spec=DistributionSpec( + providers=provider_types, + ), + ) + print_pip_install_help(build_config) else: prefix = "!" if in_notebook() else "" cprint( diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index bc15776ec..b03d2dee8 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -44,7 +44,8 @@ class RequestProviderDataContext(AbstractContextManager): class NeedsRequestProviderData: def get_request_provider_data(self) -> Any: spec = self.__provider_spec__ - assert spec, f"Provider spec not set on {self.__class__}" + if not spec: + raise ValueError(f"Provider spec not set on {self.__class__}") provider_type = spec.provider_type validator_class = spec.provider_data_validator diff --git a/llama_stack/distribution/ui/page/playground/chat.py b/llama_stack/distribution/ui/page/playground/chat.py index 8e7345169..fcaf08795 100644 --- a/llama_stack/distribution/ui/page/playground/chat.py +++ b/llama_stack/distribution/ui/page/playground/chat.py @@ -124,7 +124,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"): message_placeholder.markdown(full_response + "▌") message_placeholder.markdown(full_response) else: - full_response = response - message_placeholder.markdown(full_response.completion_message.content) + full_response = response.completion_message.content + message_placeholder.markdown(full_response) st.session_state.messages.append({"role": "assistant", "content": full_response}) diff --git a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py index 8e6f97012..ab626e5af 100644 --- a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +++ b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -245,7 +245,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 {"function_description": self._gen_function_description(custom_tools)}, ) - def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate: + def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> str: template_str = textwrap.dedent( """ Here is a list of functions in JSON format that you can invoke. @@ -286,10 +286,12 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 """ ) - return PromptTemplate( + template = PromptTemplate( template_str.strip("\n"), {"tools": [t.model_dump() for t in custom_tools]}, - ).render() + ) + rendered: str = template.render() + return rendered def data_examples(self) -> list[list[ToolDefinition]]: return [ diff --git a/llama_stack/models/llama/sku_list.py b/llama_stack/models/llama/sku_list.py index a82cbf708..271cec63f 100644 --- a/llama_stack/models/llama/sku_list.py +++ b/llama_stack/models/llama/sku_list.py @@ -948,6 +948,8 @@ def llama_meta_net_info(model: Model) -> LlamaDownloadInfo: elif model.core_model_id == CoreModelId.llama_guard_2_8b: folder = "llama-guard-2" else: + if model.huggingface_repo is None: + raise ValueError(f"Model {model.core_model_id} has no huggingface_repo set") folder = model.huggingface_repo.split("/")[-1] if "Llama-2" in folder: folder = folder.lower() @@ -1024,3 +1026,4 @@ def llama_meta_pth_size(model: Model) -> int: return 54121549657 else: return 100426653046 + return 0 diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 0cf63097b..d0d45b429 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -139,6 +139,8 @@ class OllamaInferenceAdapter( if sampling_params is None: sampling_params = SamplingParams() model = await self._get_model(model_id) + if model.provider_resource_id is None: + raise ValueError(f"Model {model_id} has no provider_resource_id set") request = CompletionRequest( model=model.provider_resource_id, content=content, @@ -202,6 +204,8 @@ class OllamaInferenceAdapter( if sampling_params is None: sampling_params = SamplingParams() model = await self._get_model(model_id) + if model.provider_resource_id is None: + raise ValueError(f"Model {model_id} has no provider_resource_id set") request = ChatCompletionRequest( model=model.provider_resource_id, messages=messages, @@ -346,6 +350,8 @@ class OllamaInferenceAdapter( # - models not currently running are run by the ollama server as needed response = await self.client.list() available_models = [m["model"] for m in response["models"]] + if model.provider_resource_id is None: + raise ValueError("Model provider_resource_id cannot be None") provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id) if provider_resource_id is None: provider_resource_id = model.provider_resource_id diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index addf2d35b..8bc733fd3 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -272,6 +272,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): if sampling_params is None: sampling_params = SamplingParams() model = await self._get_model(model_id) + if model.provider_resource_id is None: + raise ValueError(f"Model {model_id} has no provider_resource_id set") request = CompletionRequest( model=model.provider_resource_id, content=content, @@ -302,6 +304,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): if sampling_params is None: sampling_params = SamplingParams() model = await self._get_model(model_id) + if model.provider_resource_id is None: + raise ValueError(f"Model {model_id} has no provider_resource_id set") # This is to be consistent with OpenAI API and support vLLM <= v0.6.3 # References: # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index d53b51537..56e33cfdf 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -382,7 +382,7 @@ def augment_messages_for_tools_llama_3_1( messages.append(SystemMessage(content=sys_content)) - has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools) + has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools) if has_custom_tools: fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json if fmt == ToolPromptFormat.json: diff --git a/pyproject.toml b/pyproject.toml index c47e4ec9e..d3cc819be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -203,58 +203,24 @@ follow_imports = "silent" # to exclude the entire directory. exclude = [ # As we fix more and more of these, we should remove them from the list - "^llama_stack/apis/agents/agents\\.py$", - "^llama_stack/apis/batch_inference/batch_inference\\.py$", - "^llama_stack/apis/benchmarks/benchmarks\\.py$", - "^llama_stack/apis/common/content_types\\.py$", "^llama_stack/apis/common/training_types\\.py$", - "^llama_stack/apis/datasetio/datasetio\\.py$", - "^llama_stack/apis/datasets/datasets\\.py$", - "^llama_stack/apis/eval/eval\\.py$", - "^llama_stack/apis/files/files\\.py$", - "^llama_stack/apis/inference/inference\\.py$", - "^llama_stack/apis/inspect/inspect\\.py$", - "^llama_stack/apis/models/models\\.py$", - "^llama_stack/apis/post_training/post_training\\.py$", - "^llama_stack/apis/providers/providers\\.py$", - "^llama_stack/apis/resource\\.py$", - "^llama_stack/apis/safety/safety\\.py$", - "^llama_stack/apis/scoring/scoring\\.py$", - "^llama_stack/apis/scoring_functions/scoring_functions\\.py$", - "^llama_stack/apis/shields/shields\\.py$", - "^llama_stack/apis/synthetic_data_generation/synthetic_data_generation\\.py$", - "^llama_stack/apis/telemetry/telemetry\\.py$", - "^llama_stack/apis/tools/rag_tool\\.py$", - "^llama_stack/apis/tools/tools\\.py$", - "^llama_stack/apis/vector_dbs/vector_dbs\\.py$", - "^llama_stack/apis/vector_io/vector_io\\.py$", "^llama_stack/cli/download\\.py$", - "^llama_stack/cli/llama\\.py$", "^llama_stack/cli/stack/_build\\.py$", - "^llama_stack/cli/stack/list_providers\\.py$", "^llama_stack/distribution/build\\.py$", "^llama_stack/distribution/client\\.py$", - "^llama_stack/distribution/configure\\.py$", - "^llama_stack/distribution/library_client\\.py$", "^llama_stack/distribution/request_headers\\.py$", "^llama_stack/distribution/routers/", "^llama_stack/distribution/server/endpoints\\.py$", "^llama_stack/distribution/server/server\\.py$", - "^llama_stack/distribution/server/websocket_server\\.py$", "^llama_stack/distribution/stack\\.py$", "^llama_stack/distribution/store/registry\\.py$", - "^llama_stack/distribution/ui/page/playground/chat\\.py$", "^llama_stack/distribution/utils/exec\\.py$", "^llama_stack/distribution/utils/prompt_for_config\\.py$", - "^llama_stack/models/llama/datatypes\\.py$", "^llama_stack/models/llama/llama3/chat_format\\.py$", "^llama_stack/models/llama/llama3/interface\\.py$", - "^llama_stack/models/llama/llama3/prompt_templates/system_prompts\\.py$", "^llama_stack/models/llama/llama3/tokenizer\\.py$", "^llama_stack/models/llama/llama3/tool_utils\\.py$", "^llama_stack/models/llama/llama3_3/prompts\\.py$", - "^llama_stack/models/llama/llama4/", - "^llama_stack/models/llama/sku_list\\.py$", "^llama_stack/providers/inline/agents/meta_reference/", "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$", "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$", @@ -333,7 +299,6 @@ exclude = [ "^llama_stack/providers/utils/telemetry/dataset_mixin\\.py$", "^llama_stack/providers/utils/telemetry/trace_protocol\\.py$", "^llama_stack/providers/utils/telemetry/tracing\\.py$", - "^llama_stack/scripts/", "^llama_stack/strong_typing/auxiliary\\.py$", "^llama_stack/strong_typing/deserializer\\.py$", "^llama_stack/strong_typing/inspection\\.py$",