diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a7ece3b25..9b8b9a8df 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,13 +29,8 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.4 hooks: - # Run the linter with import sorting. - id: ruff - args: [ - --fix, - --exit-non-zero-on-fix, - --select, I, - ] + exclude: ^llama_stack/strong_typing/.*$ - id: ruff-format - repo: https://github.com/adamchainz/blacken-docs @@ -49,7 +44,13 @@ repos: rev: 0.5.26 hooks: - id: uv-export - args: ["--frozen", "--no-hashes", "--no-emit-project"] + args: [ + "--frozen", + "--no-hashes", + "--no-emit-project", + "--output-file=requirements.txt" + ] + files: ^pyproject\.toml$ - id: uv-sync # - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/.ruff.toml b/.ruff.toml deleted file mode 100644 index a913ae690..000000000 --- a/.ruff.toml +++ /dev/null @@ -1,37 +0,0 @@ -# Suggested config from pytorch that we can adapt -lint.select = ["B", "C", "E" , "F" , "N", "W", "B9"] - -line-length = 120 - -# C408 ignored because we like the dict keyword argument syntax -# E501 is not flexible enough, we're using B950 instead -# N812 ignored because import torch.nn.functional as F is PyTorch convention -# N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP) -# E731 allow usage of assigning lambda expressions -# E701 let black auto-format statements on one line -# E704 let black auto-format statements on one line -lint.ignore = [ - "E203", "E305", "E402", "E501", "E721", "E741", "F405", "F821", "F841", - "C408", "E302", "W291", "E303", "N812", "N817", "E731", "E701", - # These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later. - "C901", "C405", "C414", "N803", "N999", "C403", "C416", "B028", "C419", "C401", "B023", - # shebang has extra meaning in fbcode lints, so I think it's not worth trying - # to line this up with executable bit - "EXE001", - # random naming hints don't need - "N802", - # these ignores are from flake8-bugbear; please fix! 
- "B007", "B008" -] - -exclude = [ - "./.git", - "./docs/*", - "./build", - "./scripts", - "./venv", - "*.pyi", - ".pre-commit-config.yaml", - "*.md", - ".flake8" -] diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 98270f7b8..2b6e1d11c 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -40,6 +40,286 @@ } ], "paths": { + "/v1/eval/tasks/{task_id}/evaluations": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/EvaluateResponse" + } + } + } + } + }, + "tags": [ + "Eval" + ], + "description": "", + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DeprecatedEvaluateRowsRequest" + } + } + }, + "required": true + }, + "deprecated": true + } + }, + "/v1/eval-tasks/{eval_task_id}": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "$ref": "#/components/schemas/Benchmark" + }, + { + "type": "null" + } + ] + } + } + } + } + }, + "tags": [ + "Benchmarks" + ], + "description": "", + "parameters": [ + { + "name": "eval_task_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "deprecated": true + } + }, + "/v1/eval/tasks/{task_id}/jobs/{job_id}": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "$ref": "#/components/schemas/JobStatus" + }, + { + "type": "null" + } + ] + } + } + } + } + }, + "tags": [ + "Eval" + ], + "description": "", + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "deprecated": true + }, + "delete": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "Eval" + ], + "description": "", + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "deprecated": true + } + }, + "/v1/eval/tasks/{task_id}/jobs/{job_id}/result": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/EvaluateResponse" + } + } + } + } + }, + "tags": [ + "Eval" + ], + "description": "", + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "job_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "deprecated": true + } + }, + "/v1/eval-tasks": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListBenchmarksResponse" + } + } + } + } + }, + "tags": [ + "Benchmarks" + ], + "description": "", + "parameters": [], + "deprecated": true + }, + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "Benchmarks" + ], + "description": "", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": 
"#/components/schemas/DeprecatedRegisterEvalTaskRequest" + } + } + }, + "required": true + }, + "deprecated": true + } + }, + "/v1/eval/tasks/{task_id}/jobs": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Job" + } + } + } + } + }, + "tags": [ + "Eval" + ], + "description": "", + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DeprecatedRunEvalRequest" + } + } + }, + "required": true + }, + "deprecated": true + } + }, "/v1/datasetio/rows": { "get": { "responses": { @@ -530,7 +810,7 @@ } } }, - "/v1/eval/tasks/{task_id}/evaluations": { + "/v1/eval/benchmarks/{benchmark_id}/evaluations": { "post": { "responses": { "200": { @@ -550,7 +830,7 @@ "description": "", "parameters": [ { - "name": "task_id", + "name": "benchmark_id", "in": "path", "required": true, "schema": { @@ -670,6 +950,43 @@ ] } }, + "/v1/eval/benchmarks/{benchmark_id}": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "$ref": "#/components/schemas/Benchmark" + }, + { + "type": "null" + } + ] + } + } + } + } + }, + "tags": [ + "Benchmarks" + ], + "description": "", + "parameters": [ + { + "name": "benchmark_id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ] + } + }, "/v1/datasets/{dataset_id}": { "get": { "responses": { @@ -728,43 +1045,6 @@ ] } }, - "/v1/eval-tasks/{eval_task_id}": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/EvalTask" - }, - { - "type": "null" - } - ] - } - } - } - } - }, - "tags": [ - "EvalTasks" - ], - "description": "", - "parameters": [ - { - "name": "eval_task_id", - "in": "path", - "required": true, - "schema": { - "type": "string" - } - } - ] - } - }, "/v1/models/{model_id}": { "get": { "responses": { @@ -1348,7 +1628,7 @@ } } }, - "/v1/eval/tasks/{task_id}/jobs/{job_id}": { + "/v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}": { "get": { "responses": { "200": { @@ -1375,7 +1655,7 @@ "description": "", "parameters": [ { - "name": "task_id", + "name": "benchmark_id", "in": "path", "required": true, "schema": { @@ -1404,7 +1684,7 @@ "description": "", "parameters": [ { - "name": "task_id", + "name": "benchmark_id", "in": "path", "required": true, "schema": { @@ -1422,7 +1702,7 @@ ] } }, - "/v1/eval/tasks/{task_id}/jobs/{job_id}/result": { + "/v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result": { "get": { "responses": { "200": { @@ -1442,7 +1722,7 @@ "description": "", "parameters": [ { - "name": "job_id", + "name": "benchmark_id", "in": "path", "required": true, "schema": { @@ -1450,7 +1730,7 @@ } }, { - "name": "task_id", + "name": "job_id", "in": "path", "required": true, "schema": { @@ -1460,6 +1740,49 @@ ] } }, + "/v1/eval/benchmarks": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListBenchmarksResponse" + } + } + } + } + }, + "tags": [ + "Benchmarks" + ], + "description": "", + "parameters": [] + }, + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "Benchmarks" + ], + "description": "", + "parameters": [], + "requestBody": { + "content": { + 
"application/json": { + "schema": { + "$ref": "#/components/schemas/RegisterBenchmarkRequest" + } + } + }, + "required": true + } + } + }, "/v1/datasets": { "get": { "responses": { @@ -1503,49 +1826,6 @@ } } }, - "/v1/eval-tasks": { - "get": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ListEvalTasksResponse" - } - } - } - } - }, - "tags": [ - "EvalTasks" - ], - "description": "", - "parameters": [] - }, - "post": { - "responses": { - "200": { - "description": "OK" - } - }, - "tags": [ - "EvalTasks" - ], - "description": "", - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/RegisterEvalTaskRequest" - } - } - }, - "required": true - } - } - }, "/v1/models": { "get": { "responses": { @@ -2121,7 +2401,7 @@ ] } }, - "/v1/eval/tasks/{task_id}/jobs": { + "/v1/eval/benchmarks/{benchmark_id}/jobs": { "post": { "responses": { "200": { @@ -2141,7 +2421,7 @@ "description": "", "parameters": [ { - "name": "task_id", + "name": "benchmark_id", "in": "path", "required": true, "schema": { @@ -2365,84 +2645,227 @@ "jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema", "components": { "schemas": { - "AppendRowsRequest": { + "AgentCandidate": { "type": "object", "properties": { - "dataset_id": { - "type": "string" + "type": { + "type": "string", + "const": "agent", + "default": "agent" }, - "rows": { + "config": { + "$ref": "#/components/schemas/AgentConfig" + } + }, + "additionalProperties": false, + "required": [ + "type", + "config" + ], + "title": "AgentCandidate" + }, + "AgentConfig": { + "type": "object", + "properties": { + "sampling_params": { + "$ref": "#/components/schemas/SamplingParams" + }, + "input_shields": { "type": "array", "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] + "type": "string" + } + }, + "output_shields": { + "type": "array", + "items": { + "type": "string" + } + }, + "toolgroups": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AgentTool" + } + }, + "client_tools": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolDef" + } + }, + "tool_choice": { + "type": "string", + "enum": [ + "auto", + "required", + "none" + ], + "title": "ToolChoice", + "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. 
It depends on the Instruction Following capabilities of the model.", + "deprecated": true + }, + "tool_prompt_format": { + "type": "string", + "enum": [ + "json", + "function_tag", + "python_list" + ], + "title": "ToolPromptFormat", + "description": "Prompt format for calling custom / zero shot tools.", + "deprecated": true + }, + "tool_config": { + "$ref": "#/components/schemas/ToolConfig" + }, + "max_infer_iters": { + "type": "integer", + "default": 10 + }, + "model": { + "type": "string" + }, + "instructions": { + "type": "string" + }, + "enable_session_persistence": { + "type": "boolean", + "default": false + }, + "response_format": { + "$ref": "#/components/schemas/ResponseFormat" + } + }, + "additionalProperties": false, + "required": [ + "model", + "instructions" + ], + "title": "AgentConfig" + }, + "AgentTool": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "args": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } } + }, + "additionalProperties": false, + "required": [ + "name", + "args" + ], + "title": "AgentToolGroupWithArgs" + } + ] + }, + "AggregationFunctionType": { + "type": "string", + "enum": [ + "average", + "median", + "categorical_count", + "accuracy" + ], + "title": "AggregationFunctionType" + }, + "BasicScoringFnParams": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "basic", + "default": "basic" + }, + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" } } }, "additionalProperties": false, "required": [ - "dataset_id", - "rows" - ] + "type" + ], + "title": "BasicScoringFnParams" }, - "CompletionMessage": { + "BenchmarkConfig": { "type": "object", "properties": { - "role": { + "type": { "type": "string", - "const": "assistant", - "default": "assistant", - "description": "Must be \"assistant\" to identify this as the model's response" + "const": "benchmark", + "default": "benchmark" }, - "content": { - "$ref": "#/components/schemas/InterleavedContent", - "description": "The content of the model's response" + "eval_candidate": { + "$ref": "#/components/schemas/EvalCandidate" }, - "stop_reason": { - "type": "string", - "enum": [ - "end_of_turn", - "end_of_message", - "out_of_tokens" - ], - "description": "Reason why the model stopped generating. Options are: - `StopReason.end_of_turn`: The model finished generating the entire response. - `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response. - `StopReason.out_of_tokens`: The model ran out of token budget." + "scoring_params": { + "type": "object", + "additionalProperties": { + "$ref": "#/components/schemas/ScoringFnParams" + } }, - "tool_calls": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolCall" - }, - "description": "List of tool calls. Each tool call is a ToolCall object." + "num_examples": { + "type": "integer" } }, "additionalProperties": false, "required": [ - "role", - "content", - "stop_reason" + "type", + "eval_candidate", + "scoring_params" ], - "description": "A message containing the model's (assistant) response in a chat conversation." 
+ "title": "BenchmarkConfig" + }, + "EvalCandidate": { + "oneOf": [ + { + "$ref": "#/components/schemas/ModelCandidate" + }, + { + "$ref": "#/components/schemas/AgentCandidate" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "model": "#/components/schemas/ModelCandidate", + "agent": "#/components/schemas/AgentCandidate" + } + } }, "GrammarResponseFormat": { "type": "object", @@ -2485,6 +2908,7 @@ "type", "bnf" ], + "title": "GrammarResponseFormat", "description": "Configuration for grammar-guided response generation." }, "GreedySamplingStrategy": { @@ -2499,7 +2923,8 @@ "additionalProperties": false, "required": [ "type" - ] + ], + "title": "GreedySamplingStrategy" }, "ImageContentItem": { "type": "object", @@ -2532,6 +2957,7 @@ "type", "image" ], + "title": "ImageContentItem", "description": "A image content item" }, "InterleavedContent": { @@ -2608,32 +3034,95 @@ "type", "json_schema" ], + "title": "JsonSchemaResponseFormat", "description": "Configuration for JSON schema-guided response generation." }, - "Message": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserMessage" + "LLMAsJudgeScoringFnParams": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "llm_as_judge", + "default": "llm_as_judge" }, - { - "$ref": "#/components/schemas/SystemMessage" + "judge_model": { + "type": "string" }, - { - "$ref": "#/components/schemas/ToolResponseMessage" + "prompt_template": { + "type": "string" }, - { - "$ref": "#/components/schemas/CompletionMessage" + "judge_score_regexes": { + "type": "array", + "items": { + "type": "string" + } + }, + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + } } + }, + "additionalProperties": false, + "required": [ + "type", + "judge_model" ], - "discriminator": { - "propertyName": "role", - "mapping": { - "user": "#/components/schemas/UserMessage", - "system": "#/components/schemas/SystemMessage", - "tool": "#/components/schemas/ToolResponseMessage", - "assistant": "#/components/schemas/CompletionMessage" + "title": "LLMAsJudgeScoringFnParams" + }, + "ModelCandidate": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "model", + "default": "model" + }, + "model": { + "type": "string" + }, + "sampling_params": { + "$ref": "#/components/schemas/SamplingParams" + }, + "system_message": { + "$ref": "#/components/schemas/SystemMessage" } - } + }, + "additionalProperties": false, + "required": [ + "type", + "model", + "sampling_params" + ], + "title": "ModelCandidate" + }, + "RegexParserScoringFnParams": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "regex_parser", + "default": "regex_parser" + }, + "parsing_regexes": { + "type": "array", + "items": { + "type": "string" + } + }, + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + } + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "RegexParserScoringFnParams" }, "ResponseFormat": { "oneOf": [ @@ -2670,7 +3159,8 @@ "additionalProperties": false, "required": [ "strategy" - ] + ], + "title": "SamplingParams" }, "SamplingStrategy": { "oneOf": [ @@ -2693,6 +3183,27 @@ } } }, + "ScoringFnParams": { + "oneOf": [ + { + "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams" + }, + { + "$ref": "#/components/schemas/RegexParserScoringFnParams" + }, + { + "$ref": "#/components/schemas/BasicScoringFnParams" + } + ], + "discriminator": { 
+ "propertyName": "type", + "mapping": { + "llm_as_judge": "#/components/schemas/LLMAsJudgeScoringFnParams", + "regex_parser": "#/components/schemas/RegexParserScoringFnParams", + "basic": "#/components/schemas/BasicScoringFnParams" + } + } + }, "SystemMessage": { "type": "object", "properties": { @@ -2712,6 +3223,7 @@ "role", "content" ], + "title": "SystemMessage", "description": "A system message providing instructions or context to the model." }, "TextContentItem": { @@ -2733,8 +3245,638 @@ "type", "text" ], + "title": "TextContentItem", "description": "A text content item" }, + "ToolConfig": { + "type": "object", + "properties": { + "tool_choice": { + "oneOf": [ + { + "type": "string", + "enum": [ + "auto", + "required", + "none" + ], + "title": "ToolChoice", + "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model." + }, + { + "type": "string" + } + ], + "default": "auto", + "description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto." + }, + "tool_prompt_format": { + "type": "string", + "enum": [ + "json", + "function_tag", + "python_list" + ], + "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls." + }, + "system_message_behavior": { + "type": "string", + "enum": [ + "append", + "replace" + ], + "description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.", + "default": "append" + } + }, + "additionalProperties": false, + "title": "ToolConfig", + "description": "Configuration for tool use." 
+ }, + "ToolDef": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "parameters": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolParameter" + } + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "name" + ], + "title": "ToolDef" + }, + "ToolParameter": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "parameter_type": { + "type": "string" + }, + "description": { + "type": "string" + }, + "required": { + "type": "boolean", + "default": true + }, + "default": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "additionalProperties": false, + "required": [ + "name", + "parameter_type", + "description", + "required" + ], + "title": "ToolParameter" + }, + "TopKSamplingStrategy": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "top_k", + "default": "top_k" + }, + "top_k": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "type", + "top_k" + ], + "title": "TopKSamplingStrategy" + }, + "TopPSamplingStrategy": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "top_p", + "default": "top_p" + }, + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number", + "default": 0.95 + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "TopPSamplingStrategy" + }, + "URL": { + "type": "object", + "properties": { + "uri": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "uri" + ], + "title": "URL" + }, + "DeprecatedEvaluateRowsRequest": { + "type": "object", + "properties": { + "input_rows": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "scoring_functions": { + "type": "array", + "items": { + "type": "string" + } + }, + "task_config": { + "$ref": "#/components/schemas/BenchmarkConfig" + } + }, + "additionalProperties": false, + "required": [ + "input_rows", + "scoring_functions", + "task_config" + ], + "title": "DeprecatedEvaluateRowsRequest" + }, + "EvaluateResponse": { + "type": "object", + "properties": { + "generations": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "scores": { + "type": "object", + "additionalProperties": { + "$ref": "#/components/schemas/ScoringResult" + } + } + }, + "additionalProperties": false, + "required": [ + "generations", + "scores" + ], + "title": "EvaluateResponse" + }, + "ScoringResult": { + "type": "object", + "properties": { + "score_rows": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + 
"type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "aggregated_results": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "score_rows", + "aggregated_results" + ], + "title": "ScoringResult" + }, + "Benchmark": { + "type": "object", + "properties": { + "identifier": { + "type": "string" + }, + "provider_resource_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "benchmark", + "default": "benchmark" + }, + "dataset_id": { + "type": "string" + }, + "scoring_functions": { + "type": "array", + "items": { + "type": "string" + } + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "identifier", + "provider_resource_id", + "provider_id", + "type", + "dataset_id", + "scoring_functions", + "metadata" + ], + "title": "Benchmark" + }, + "JobStatus": { + "type": "string", + "enum": [ + "completed", + "in_progress", + "failed", + "scheduled" + ], + "title": "JobStatus" + }, + "ListBenchmarksResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Benchmark" + } + } + }, + "additionalProperties": false, + "required": [ + "data" + ], + "title": "ListBenchmarksResponse" + }, + "DeprecatedRegisterEvalTaskRequest": { + "type": "object", + "properties": { + "eval_task_id": { + "type": "string" + }, + "dataset_id": { + "type": "string" + }, + "scoring_functions": { + "type": "array", + "items": { + "type": "string" + } + }, + "provider_benchmark_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "eval_task_id", + "dataset_id", + "scoring_functions" + ], + "title": "DeprecatedRegisterEvalTaskRequest" + }, + "DeprecatedRunEvalRequest": { + "type": "object", + "properties": { + "task_config": { + "$ref": "#/components/schemas/BenchmarkConfig" + } + }, + "additionalProperties": false, + "required": [ + "task_config" + ], + "title": "DeprecatedRunEvalRequest" + }, + "Job": { + "type": "object", + "properties": { + "job_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "job_id" + ], + "title": "Job" + }, + "AppendRowsRequest": { + "type": "object", + "properties": { + "dataset_id": { + "type": "string" + }, + "rows": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + } + }, + "additionalProperties": false, + "required": [ + "dataset_id", + "rows" + ], + "title": "AppendRowsRequest" + }, + "CompletionMessage": { + "type": "object", + "properties": { + "role": 
{ + "type": "string", + "const": "assistant", + "default": "assistant", + "description": "Must be \"assistant\" to identify this as the model's response" + }, + "content": { + "$ref": "#/components/schemas/InterleavedContent", + "description": "The content of the model's response" + }, + "stop_reason": { + "type": "string", + "enum": [ + "end_of_turn", + "end_of_message", + "out_of_tokens" + ], + "description": "Reason why the model stopped generating. Options are: - `StopReason.end_of_turn`: The model finished generating the entire response. - `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response. - `StopReason.out_of_tokens`: The model ran out of token budget." + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolCall" + }, + "description": "List of tool calls. Each tool call is a ToolCall object." + } + }, + "additionalProperties": false, + "required": [ + "role", + "content", + "stop_reason" + ], + "title": "CompletionMessage", + "description": "A message containing the model's (assistant) response in a chat conversation." + }, + "Message": { + "oneOf": [ + { + "$ref": "#/components/schemas/UserMessage" + }, + { + "$ref": "#/components/schemas/SystemMessage" + }, + { + "$ref": "#/components/schemas/ToolResponseMessage" + }, + { + "$ref": "#/components/schemas/CompletionMessage" + } + ], + "discriminator": { + "propertyName": "role", + "mapping": { + "user": "#/components/schemas/UserMessage", + "system": "#/components/schemas/SystemMessage", + "tool": "#/components/schemas/ToolResponseMessage", + "assistant": "#/components/schemas/CompletionMessage" + } + } + }, "ToolCall": { "type": "object", "properties": { @@ -2750,7 +3892,8 @@ "wolfram_alpha", "photogen", "code_interpreter" - ] + ], + "title": "BuiltinTool" }, { "type": "string" @@ -2829,7 +3972,8 @@ "call_id", "tool_name", "arguments" - ] + ], + "title": "ToolCall" }, "ToolDefinition": { "type": "object", @@ -2843,7 +3987,8 @@ "wolfram_alpha", "photogen", "code_interpreter" - ] + ], + "title": "BuiltinTool" }, { "type": "string" @@ -2863,7 +4008,8 @@ "additionalProperties": false, "required": [ "tool_name" - ] + ], + "title": "ToolDefinition" }, "ToolParamDefinition": { "type": "object", @@ -2904,7 +4050,8 @@ "additionalProperties": false, "required": [ "param_type" - ] + ], + "title": "ToolParamDefinition" }, "ToolResponseMessage": { "type": "object", @@ -2928,7 +4075,8 @@ "wolfram_alpha", "photogen", "code_interpreter" - ] + ], + "title": "BuiltinTool" }, { "type": "string" @@ -2948,59 +4096,9 @@ "tool_name", "content" ], + "title": "ToolResponseMessage", "description": "A message representing the result of a tool invocation." 
}, - "TopKSamplingStrategy": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "top_k", - "default": "top_k" - }, - "top_k": { - "type": "integer" - } - }, - "additionalProperties": false, - "required": [ - "type", - "top_k" - ] - }, - "TopPSamplingStrategy": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "top_p", - "default": "top_p" - }, - "temperature": { - "type": "number" - }, - "top_p": { - "type": "number", - "default": 0.95 - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - "URL": { - "type": "object", - "properties": { - "uri": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "uri" - ] - }, "UserMessage": { "type": "object", "properties": { @@ -3024,6 +4122,7 @@ "role", "content" ], + "title": "UserMessage", "description": "A message from the user in a chat conversation." }, "BatchChatCompletionRequest": { @@ -3054,8 +4153,10 @@ "type": "string", "enum": [ "auto", - "required" + "required", + "none" ], + "title": "ToolChoice", "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model." }, "tool_prompt_format": { @@ -3065,6 +4166,7 @@ "function_tag", "python_list" ], + "title": "ToolPromptFormat", "description": "Prompt format for calling custom / zero shot tools." }, "response_format": { @@ -3079,14 +4181,16 @@ "description": "How many tokens (for each position) to return log probabilities for." } }, - "additionalProperties": false + "additionalProperties": false, + "title": "LogProbConfig" } }, "additionalProperties": false, "required": [ "model", "messages_batch" - ] + ], + "title": "BatchChatCompletionRequest" }, "BatchChatCompletionResponse": { "type": "object", @@ -3101,7 +4205,8 @@ "additionalProperties": false, "required": [ "batch" - ] + ], + "title": "BatchChatCompletionResponse" }, "ChatCompletionResponse": { "type": "object", @@ -3128,6 +4233,7 @@ "required": [ "completion_message" ], + "title": "ChatCompletionResponse", "description": "Response from a chat completion request." }, "MetricEvent": { @@ -3196,7 +4302,8 @@ "metric", "value", "unit" - ] + ], + "title": "MetricEvent" }, "TokenLogProbs": { "type": "object", @@ -3213,6 +4320,7 @@ "required": [ "logprobs_by_token" ], + "title": "TokenLogProbs", "description": "Log probabilities for generated tokens." }, "BatchCompletionRequest": { @@ -3242,14 +4350,16 @@ "description": "How many tokens (for each position) to return log probabilities for." } }, - "additionalProperties": false + "additionalProperties": false, + "title": "LogProbConfig" } }, "additionalProperties": false, "required": [ "model", "content_batch" - ] + ], + "title": "BatchCompletionRequest" }, "BatchCompletionResponse": { "type": "object", @@ -3264,7 +4374,8 @@ "additionalProperties": false, "required": [ "batch" - ] + ], + "title": "BatchCompletionResponse" }, "CompletionResponse": { "type": "object", @@ -3295,6 +4406,7 @@ "content", "stop_reason" ], + "title": "CompletionResponse", "description": "Response from a completion request." }, "CancelTrainingJobRequest": { @@ -3307,44 +4419,8 @@ "additionalProperties": false, "required": [ "job_uuid" - ] - }, - "ToolConfig": { - "type": "object", - "properties": { - "tool_choice": { - "type": "string", - "enum": [ - "auto", - "required" - ], - "description": "(Optional) Whether tool use is required or automatic. 
Defaults to ToolChoice.auto.", - "default": "auto" - }, - "tool_prompt_format": { - "type": "string", - "enum": [ - "json", - "function_tag", - "python_list" - ], - "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls." - }, - "system_message_behavior": { - "type": "string", - "enum": [ - "append", - "replace" - ], - "description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.", - "default": "append" - } - }, - "additionalProperties": false, - "required": [ - "system_message_behavior" ], - "description": "Configuration for tool use." + "title": "CancelTrainingJobRequest" }, "ChatCompletionRequest": { "type": "object", @@ -3375,7 +4451,8 @@ "type": "string", "enum": [ "auto", - "required" + "required", + "none" ], "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. .. deprecated:: Use tool_config instead." }, @@ -3417,7 +4494,8 @@ "required": [ "model_id", "messages" - ] + ], + "title": "ChatCompletionRequest" }, "ChatCompletionResponseEvent": { "type": "object", @@ -3457,6 +4535,7 @@ "event_type", "delta" ], + "title": "ChatCompletionResponseEvent", "description": "An event during chat completion generation." }, "ChatCompletionResponseStreamChunk": { @@ -3477,6 +4556,7 @@ "required": [ "event" ], + "title": "ChatCompletionResponseStreamChunk", "description": "A chunk of a streamed chat completion response." }, "ContentDelta": { @@ -3517,7 +4597,8 @@ "required": [ "type", "image" - ] + ], + "title": "ImageDelta" }, "TextDelta": { "type": "object", @@ -3535,7 +4616,8 @@ "required": [ "type", "text" - ] + ], + "title": "TextDelta" }, "ToolCallDelta": { "type": "object", @@ -3562,7 +4644,8 @@ "in_progress", "failed", "succeeded" - ] + ], + "title": "ToolCallParseStatus" } }, "additionalProperties": false, @@ -3570,7 +4653,8 @@ "type", "tool_call", "parse_status" - ] + ], + "title": "ToolCallDelta" }, "CompletionRequest": { "type": "object", @@ -3612,7 +4696,8 @@ "required": [ "model_id", "content" - ] + ], + "title": "CompletionRequest" }, "CompletionResponseStreamChunk": { "type": "object", @@ -3642,220 +4727,9 @@ "required": [ "delta" ], + "title": "CompletionResponseStreamChunk", "description": "A chunk of a streamed completion response." 
}, - "AgentConfig": { - "type": "object", - "properties": { - "sampling_params": { - "$ref": "#/components/schemas/SamplingParams" - }, - "input_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "output_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "toolgroups": { - "type": "array", - "items": { - "$ref": "#/components/schemas/AgentTool" - } - }, - "client_tools": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolDef" - } - }, - "tool_choice": { - "type": "string", - "enum": [ - "auto", - "required" - ], - "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model." - }, - "tool_prompt_format": { - "type": "string", - "enum": [ - "json", - "function_tag", - "python_list" - ], - "description": "Prompt format for calling custom / zero shot tools." - }, - "tool_config": { - "$ref": "#/components/schemas/ToolConfig" - }, - "max_infer_iters": { - "type": "integer", - "default": 10 - }, - "model": { - "type": "string" - }, - "instructions": { - "type": "string" - }, - "enable_session_persistence": { - "type": "boolean" - }, - "response_format": { - "$ref": "#/components/schemas/ResponseFormat" - } - }, - "additionalProperties": false, - "required": [ - "model", - "instructions", - "enable_session_persistence" - ] - }, - "AgentTool": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "args": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "name", - "args" - ] - } - ] - }, - "ToolDef": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "description": { - "type": "string" - }, - "parameters": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolParameter" - } - }, - "metadata": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "name" - ] - }, - "ToolParameter": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "parameter_type": { - "type": "string" - }, - "description": { - "type": "string" - }, - "required": { - "type": "boolean", - "default": true - }, - "default": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "additionalProperties": false, - "required": [ - "name", - "parameter_type", - "description", - "required" - ] - }, "CreateAgentRequest": { "type": "object", "properties": { @@ -3866,7 +4740,8 @@ "additionalProperties": false, "required": [ "agent_config" - ] + ], + "title": "CreateAgentRequest" }, "AgentCreateResponse": { "type": "object", @@ -3878,7 +4753,8 @@ "additionalProperties": false, "required": [ "agent_id" - ] + ], + "title": "AgentCreateResponse" }, "CreateAgentSessionRequest": { "type": "object", @@ -3890,7 +4766,8 @@ "additionalProperties": false, "required": [ "session_name" - ] + ], + "title": 
"CreateAgentSessionRequest" }, "AgentSessionCreateResponse": { "type": "object", @@ -3902,7 +4779,8 @@ "additionalProperties": false, "required": [ "session_id" - ] + ], + "title": "AgentSessionCreateResponse" }, "CreateAgentTurnRequest": { "type": "object", @@ -3955,7 +4833,8 @@ "required": [ "content", "mime_type" - ] + ], + "title": "Document" } }, "toolgroups": { @@ -3971,7 +4850,8 @@ "additionalProperties": false, "required": [ "messages" - ] + ], + "title": "CreateAgentTurnRequest" }, "InferenceStep": { "type": "object", @@ -4005,7 +4885,8 @@ "step_id", "step_type", "model_response" - ] + ], + "title": "InferenceStep" }, "MemoryRetrievalStep": { "type": "object", @@ -4043,7 +4924,8 @@ "step_type", "vector_db_ids", "inserted_context" - ] + ], + "title": "MemoryRetrievalStep" }, "SafetyViolation": { "type": "object", @@ -4084,7 +4966,8 @@ "required": [ "violation_level", "metadata" - ] + ], + "title": "SafetyViolation" }, "ShieldCallStep": { "type": "object", @@ -4117,7 +5000,8 @@ "turn_id", "step_id", "step_type" - ] + ], + "title": "ShieldCallStep" }, "ToolExecutionStep": { "type": "object", @@ -4161,7 +5045,8 @@ "step_type", "tool_calls", "tool_responses" - ] + ], + "title": "ToolExecutionStep" }, "ToolResponse": { "type": "object", @@ -4178,7 +5063,8 @@ "wolfram_alpha", "photogen", "code_interpreter" - ] + ], + "title": "BuiltinTool" }, { "type": "string" @@ -4194,7 +5080,8 @@ "call_id", "tool_name", "content" - ] + ], + "title": "ToolResponse" }, "Turn": { "type": "object", @@ -4281,7 +5168,8 @@ "required": [ "content", "mime_type" - ] + ], + "title": "Attachment" } }, "started_at": { @@ -4302,6 +5190,7 @@ "output_message", "started_at" ], + "title": "Turn", "description": "A single turn in an interaction with an Agentic System." }, "ViolationLevel": { @@ -4310,7 +5199,8 @@ "info", "warn", "error" - ] + ], + "title": "ViolationLevel" }, "AgentTurnResponseEvent": { "type": "object", @@ -4322,7 +5212,8 @@ "additionalProperties": false, "required": [ "payload" - ] + ], + "title": "AgentTurnResponseEvent" }, "AgentTurnResponseEventPayload": { "oneOf": [ @@ -4368,7 +5259,8 @@ "tool_execution", "shield_call", "memory_retrieval" - ] + ], + "title": "StepType" }, "step_id": { "type": "string" @@ -4405,7 +5297,8 @@ "step_type", "step_id", "step_details" - ] + ], + "title": "AgentTurnResponseStepCompletePayload" }, "AgentTurnResponseStepProgressPayload": { "type": "object", @@ -4422,7 +5315,8 @@ "tool_execution", "shield_call", "memory_retrieval" - ] + ], + "title": "StepType" }, "step_id": { "type": "string" @@ -4437,7 +5331,8 @@ "step_type", "step_id", "delta" - ] + ], + "title": "AgentTurnResponseStepProgressPayload" }, "AgentTurnResponseStepStartPayload": { "type": "object", @@ -4454,7 +5349,8 @@ "tool_execution", "shield_call", "memory_retrieval" - ] + ], + "title": "StepType" }, "step_id": { "type": "string" @@ -4490,7 +5386,8 @@ "event_type", "step_type", "step_id" - ] + ], + "title": "AgentTurnResponseStepStartPayload" }, "AgentTurnResponseStreamChunk": { "type": "object", @@ -4503,6 +5400,7 @@ "required": [ "event" ], + "title": "AgentTurnResponseStreamChunk", "description": "streamed agent turn completion response." 
}, "AgentTurnResponseTurnCompletePayload": { @@ -4521,7 +5419,8 @@ "required": [ "event_type", "turn" - ] + ], + "title": "AgentTurnResponseTurnCompletePayload" }, "AgentTurnResponseTurnStartPayload": { "type": "object", @@ -4539,7 +5438,8 @@ "required": [ "event_type", "turn_id" - ] + ], + "title": "AgentTurnResponseTurnStartPayload" }, "EmbeddingsRequest": { "type": "object", @@ -4560,7 +5460,8 @@ "required": [ "model_id", "contents" - ] + ], + "title": "EmbeddingsRequest" }, "EmbeddingsResponse": { "type": "object", @@ -4580,243 +5481,9 @@ "required": [ "embeddings" ], + "title": "EmbeddingsResponse", "description": "Response containing generated embeddings." }, - "AgentCandidate": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "agent", - "default": "agent" - }, - "config": { - "$ref": "#/components/schemas/AgentConfig" - } - }, - "additionalProperties": false, - "required": [ - "type", - "config" - ] - }, - "AggregationFunctionType": { - "type": "string", - "enum": [ - "average", - "median", - "categorical_count", - "accuracy" - ] - }, - "AppEvalTaskConfig": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "app", - "default": "app" - }, - "eval_candidate": { - "$ref": "#/components/schemas/EvalCandidate" - }, - "scoring_params": { - "type": "object", - "additionalProperties": { - "$ref": "#/components/schemas/ScoringFnParams" - } - }, - "num_examples": { - "type": "integer" - } - }, - "additionalProperties": false, - "required": [ - "type", - "eval_candidate", - "scoring_params" - ] - }, - "BasicScoringFnParams": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "basic", - "default": "basic" - }, - "aggregation_functions": { - "type": "array", - "items": { - "$ref": "#/components/schemas/AggregationFunctionType" - } - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - "BenchmarkEvalTaskConfig": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "benchmark", - "default": "benchmark" - }, - "eval_candidate": { - "$ref": "#/components/schemas/EvalCandidate" - }, - "num_examples": { - "type": "integer" - } - }, - "additionalProperties": false, - "required": [ - "type", - "eval_candidate" - ] - }, - "EvalCandidate": { - "oneOf": [ - { - "$ref": "#/components/schemas/ModelCandidate" - }, - { - "$ref": "#/components/schemas/AgentCandidate" - } - ], - "discriminator": { - "propertyName": "type", - "mapping": { - "model": "#/components/schemas/ModelCandidate", - "agent": "#/components/schemas/AgentCandidate" - } - } - }, - "EvalTaskConfig": { - "oneOf": [ - { - "$ref": "#/components/schemas/BenchmarkEvalTaskConfig" - }, - { - "$ref": "#/components/schemas/AppEvalTaskConfig" - } - ], - "discriminator": { - "propertyName": "type", - "mapping": { - "benchmark": "#/components/schemas/BenchmarkEvalTaskConfig", - "app": "#/components/schemas/AppEvalTaskConfig" - } - } - }, - "LLMAsJudgeScoringFnParams": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "llm_as_judge", - "default": "llm_as_judge" - }, - "judge_model": { - "type": "string" - }, - "prompt_template": { - "type": "string" - }, - "judge_score_regexes": { - "type": "array", - "items": { - "type": "string" - } - }, - "aggregation_functions": { - "type": "array", - "items": { - "$ref": "#/components/schemas/AggregationFunctionType" - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "judge_model" - ] - }, - "ModelCandidate": { - 
"type": "object", - "properties": { - "type": { - "type": "string", - "const": "model", - "default": "model" - }, - "model": { - "type": "string" - }, - "sampling_params": { - "$ref": "#/components/schemas/SamplingParams" - }, - "system_message": { - "$ref": "#/components/schemas/SystemMessage" - } - }, - "additionalProperties": false, - "required": [ - "type", - "model", - "sampling_params" - ] - }, - "RegexParserScoringFnParams": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "regex_parser", - "default": "regex_parser" - }, - "parsing_regexes": { - "type": "array", - "items": { - "type": "string" - } - }, - "aggregation_functions": { - "type": "array", - "items": { - "$ref": "#/components/schemas/AggregationFunctionType" - } - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - "ScoringFnParams": { - "oneOf": [ - { - "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams" - }, - { - "$ref": "#/components/schemas/RegexParserScoringFnParams" - }, - { - "$ref": "#/components/schemas/BasicScoringFnParams" - } - ], - "discriminator": { - "propertyName": "type", - "mapping": { - "llm_as_judge": "#/components/schemas/LLMAsJudgeScoringFnParams", - "regex_parser": "#/components/schemas/RegexParserScoringFnParams", - "basic": "#/components/schemas/BasicScoringFnParams" - } - } - }, "EvaluateRowsRequest": { "type": "object", "properties": { @@ -4855,7 +5522,7 @@ } }, "task_config": { - "$ref": "#/components/schemas/EvalTaskConfig" + "$ref": "#/components/schemas/BenchmarkConfig" } }, "additionalProperties": false, @@ -4863,114 +5530,8 @@ "input_rows", "scoring_functions", "task_config" - ] - }, - "EvaluateResponse": { - "type": "object", - "properties": { - "generations": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "scores": { - "type": "object", - "additionalProperties": { - "$ref": "#/components/schemas/ScoringResult" - } - } - }, - "additionalProperties": false, - "required": [ - "generations", - "scores" - ] - }, - "ScoringResult": { - "type": "object", - "properties": { - "score_rows": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "aggregated_results": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "score_rows", - "aggregated_results" - ] + ], + "title": "EvaluateRowsRequest" }, "Session": { "type": "object", @@ -4999,6 +5560,7 @@ "turns", "started_at" ], + "title": "Session", "description": "A single session of an interaction with an Agentic System." 
}, "AgentStepResponse": { @@ -5033,7 +5595,8 @@ "additionalProperties": false, "required": [ "step" - ] + ], + "title": "AgentStepResponse" }, "AgentTurnInputType": { "type": "object", @@ -5047,7 +5610,8 @@ "additionalProperties": false, "required": [ "type" - ] + ], + "title": "AgentTurnInputType" }, "ArrayType": { "type": "object", @@ -5061,7 +5625,8 @@ "additionalProperties": false, "required": [ "type" - ] + ], + "title": "ArrayType" }, "BooleanType": { "type": "object", @@ -5075,7 +5640,8 @@ "additionalProperties": false, "required": [ "type" - ] + ], + "title": "BooleanType" }, "ChatCompletionInputType": { "type": "object", @@ -5089,7 +5655,8 @@ "additionalProperties": false, "required": [ "type" - ] + ], + "title": "ChatCompletionInputType" }, "CompletionInputType": { "type": "object", @@ -5103,7 +5670,8 @@ "additionalProperties": false, "required": [ "type" - ] + ], + "title": "CompletionInputType" }, "Dataset": { "type": "object", @@ -5166,7 +5734,8 @@ "dataset_schema", "url", "metadata" - ] + ], + "title": "Dataset" }, "JsonType": { "type": "object", @@ -5180,7 +5749,8 @@ "additionalProperties": false, "required": [ "type" - ] + ], + "title": "JsonType" }, "NumberType": { "type": "object", @@ -5194,7 +5764,8 @@ "additionalProperties": false, "required": [ "type" - ] + ], + "title": "NumberType" }, "ObjectType": { "type": "object", @@ -5208,7 +5779,8 @@ "additionalProperties": false, "required": [ "type" - ] + ], + "title": "ObjectType" }, "ParamType": { "oneOf": [ @@ -5271,7 +5843,8 @@ "additionalProperties": false, "required": [ "type" - ] + ], + "title": "StringType" }, "UnionType": { "type": "object", @@ -5285,70 +5858,8 @@ "additionalProperties": false, "required": [ "type" - ] - }, - "EvalTask": { - "type": "object", - "properties": { - "identifier": { - "type": "string" - }, - "provider_resource_id": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "eval_task", - "default": "eval_task" - }, - "dataset_id": { - "type": "string" - }, - "scoring_functions": { - "type": "array", - "items": { - "type": "string" - } - }, - "metadata": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "identifier", - "provider_resource_id", - "provider_id", - "type", - "dataset_id", - "scoring_functions", - "metadata" - ] + ], + "title": "UnionType" }, "Model": { "type": "object", @@ -5405,14 +5916,16 @@ "type", "metadata", "model_type" - ] + ], + "title": "Model" }, "ModelType": { "type": "string", "enum": [ "llm", "embedding" - ] + ], + "title": "ModelType" }, "PaginatedRowsResult": { "type": "object", @@ -5456,7 +5969,8 @@ "required": [ "rows", "total_count" - ] + ], + "title": "PaginatedRowsResult" }, "ScoringFn": { "type": "object", @@ -5518,7 +6032,8 @@ "type", "metadata", "return_type" - ] + ], + "title": "ScoringFn" }, "Shield": { "type": "object", @@ -5570,6 +6085,7 @@ "provider_id", "type" ], + "title": "Shield", "description": "A safety shield resource that can be used to check content" }, "Span": { @@ -5627,14 +6143,16 @@ "trace_id", "name", "start_time" - ] + ], + "title": "Span" }, "SpanStatus": { "type": "string", "enum": [ "ok", "error" - ] + ], + "title": "SpanStatus" }, "SpanWithStatus": { "type": "object", @@ -5694,7 +6212,8 @@ "trace_id", "name", "start_time" - ] + ], + 
"title": "SpanWithStatus" }, "QuerySpanTreeResponse": { "type": "object", @@ -5709,7 +6228,8 @@ "additionalProperties": false, "required": [ "data" - ] + ], + "title": "QuerySpanTreeResponse" }, "Tool": { "type": "object", @@ -5779,7 +6299,8 @@ "tool_host", "description", "parameters" - ] + ], + "title": "Tool" }, "ToolHost": { "type": "string", @@ -5787,7 +6308,8 @@ "distribution", "client", "model_context_protocol" - ] + ], + "title": "ToolHost" }, "ToolGroup": { "type": "object", @@ -5841,7 +6363,8 @@ "provider_resource_id", "provider_id", "type" - ] + ], + "title": "ToolGroup" }, "Trace": { "type": "object", @@ -5866,10 +6389,12 @@ "trace_id", "root_span_id", "start_time" - ] + ], + "title": "Trace" }, "Checkpoint": { - "description": "Checkpoint created during training runs" + "description": "Checkpoint created during training runs", + "title": "Checkpoint" }, "PostTrainingJobArtifactsResponse": { "type": "object", @@ -5889,17 +6414,9 @@ "job_uuid", "checkpoints" ], + "title": "PostTrainingJobArtifactsResponse", "description": "Artifacts of a finetuning job." }, - "JobStatus": { - "type": "string", - "enum": [ - "completed", - "in_progress", - "failed", - "scheduled" - ] - }, "PostTrainingJobStatusResponse": { "type": "object", "properties": { @@ -5959,6 +6476,7 @@ "status", "checkpoints" ], + "title": "PostTrainingJobStatusResponse", "description": "Status of a finetuning job." }, "ListPostTrainingJobsResponse": { @@ -5976,14 +6494,16 @@ "additionalProperties": false, "required": [ "job_uuid" - ] + ], + "title": "PostTrainingJob" } } }, "additionalProperties": false, "required": [ "data" - ] + ], + "title": "ListPostTrainingJobsResponse" }, "VectorDB": { "type": "object", @@ -6017,7 +6537,8 @@ "type", "embedding_model", "embedding_dimension" - ] + ], + "title": "VectorDB" }, "HealthInfo": { "type": "object", @@ -6029,7 +6550,8 @@ "additionalProperties": false, "required": [ "status" - ] + ], + "title": "HealthInfo" }, "RAGDocument": { "type": "object", @@ -6090,7 +6612,8 @@ "document_id", "content", "metadata" - ] + ], + "title": "RAGDocument" }, "InsertRequest": { "type": "object", @@ -6113,7 +6636,8 @@ "documents", "vector_db_id", "chunk_size_in_tokens" - ] + ], + "title": "InsertRequest" }, "InsertChunksRequest": { "type": "object", @@ -6159,7 +6683,8 @@ "required": [ "content", "metadata" - ] + ], + "title": "Chunk" } }, "ttl_seconds": { @@ -6170,7 +6695,8 @@ "required": [ "vector_db_id", "chunks" - ] + ], + "title": "InsertChunksRequest" }, "InvokeToolRequest": { "type": "object", @@ -6208,7 +6734,8 @@ "required": [ "tool_name", "kwargs" - ] + ], + "title": "InvokeToolRequest" }, "ToolInvocationResult": { "type": "object", @@ -6226,7 +6753,8 @@ "additionalProperties": false, "required": [ "content" - ] + ], + "title": "ToolInvocationResult" }, "ListDatasetsResponse": { "type": "object", @@ -6241,22 +6769,8 @@ "additionalProperties": false, "required": [ "data" - ] - }, - "ListEvalTasksResponse": { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/components/schemas/EvalTask" - } - } - }, - "additionalProperties": false, - "required": [ - "data" - ] + ], + "title": "ListDatasetsResponse" }, "ListModelsResponse": { "type": "object", @@ -6271,7 +6785,8 @@ "additionalProperties": false, "required": [ "data" - ] + ], + "title": "ListModelsResponse" }, "ProviderInfo": { "type": "object", @@ -6291,7 +6806,8 @@ "api", "provider_id", "provider_type" - ] + ], + "title": "ProviderInfo" }, "ListProvidersResponse": { "type": "object", @@ -6306,7 
+6822,8 @@ "additionalProperties": false, "required": [ "data" - ] + ], + "title": "ListProvidersResponse" }, "RouteInfo": { "type": "object", @@ -6329,7 +6846,8 @@ "route", "method", "provider_types" - ] + ], + "title": "RouteInfo" }, "ListRoutesResponse": { "type": "object", @@ -6344,7 +6862,8 @@ "additionalProperties": false, "required": [ "data" - ] + ], + "title": "ListRoutesResponse" }, "ListScoringFunctionsResponse": { "type": "object", @@ -6359,7 +6878,8 @@ "additionalProperties": false, "required": [ "data" - ] + ], + "title": "ListScoringFunctionsResponse" }, "ListShieldsResponse": { "type": "object", @@ -6374,7 +6894,8 @@ "additionalProperties": false, "required": [ "data" - ] + ], + "title": "ListShieldsResponse" }, "ListToolGroupsResponse": { "type": "object", @@ -6389,7 +6910,8 @@ "additionalProperties": false, "required": [ "data" - ] + ], + "title": "ListToolGroupsResponse" }, "ListToolsResponse": { "type": "object", @@ -6404,7 +6926,8 @@ "additionalProperties": false, "required": [ "data" - ] + ], + "title": "ListToolsResponse" }, "ListVectorDBsResponse": { "type": "object", @@ -6419,7 +6942,8 @@ "additionalProperties": false, "required": [ "data" - ] + ], + "title": "ListVectorDBsResponse" }, "Event": { "oneOf": [ @@ -6451,7 +6975,8 @@ "warn", "error", "critical" - ] + ], + "title": "LogSeverity" }, "SpanEndPayload": { "type": "object", @@ -6469,7 +6994,8 @@ "required": [ "type", "status" - ] + ], + "title": "SpanEndPayload" }, "SpanStartPayload": { "type": "object", @@ -6490,7 +7016,8 @@ "required": [ "type", "name" - ] + ], + "title": "SpanStartPayload" }, "StructuredLogEvent": { "type": "object", @@ -6543,7 +7070,8 @@ "timestamp", "type", "payload" - ] + ], + "title": "StructuredLogEvent" }, "StructuredLogPayload": { "oneOf": [ @@ -6617,7 +7145,8 @@ "type", "message", "severity" - ] + ], + "title": "UnstructuredLogEvent" }, "LogEventRequest": { "type": "object", @@ -6633,7 +7162,8 @@ "required": [ "event", "ttl_seconds" - ] + ], + "title": "LogEventRequest" }, "DPOAlignmentConfig": { "type": "object", @@ -6657,7 +7187,8 @@ "reward_clip", "epsilon", "gamma" - ] + ], + "title": "DPOAlignmentConfig" }, "DataConfig": { "type": "object", @@ -6692,14 +7223,16 @@ "batch_size", "shuffle", "data_format" - ] + ], + "title": "DataConfig" }, "DatasetFormat": { "type": "string", "enum": [ "instruct", "dialog" - ] + ], + "title": "DatasetFormat" }, "EfficiencyConfig": { "type": "object", @@ -6721,7 +7254,8 @@ "default": false } }, - "additionalProperties": false + "additionalProperties": false, + "title": "EfficiencyConfig" }, "OptimizerConfig": { "type": "object", @@ -6745,7 +7279,8 @@ "lr", "weight_decay", "num_warmup_steps" - ] + ], + "title": "OptimizerConfig" }, "OptimizerType": { "type": "string", @@ -6753,7 +7288,8 @@ "adam", "adamw", "sgd" - ] + ], + "title": "OptimizerType" }, "TrainingConfig": { "type": "object", @@ -6792,7 +7328,8 @@ "max_validation_steps", "data_config", "optimizer_config" - ] + ], + "title": "TrainingConfig" }, "PreferenceOptimizeRequest": { "type": "object", @@ -6868,7 +7405,8 @@ "training_config", "hyperparam_search_config", "logger_config" - ] + ], + "title": "PreferenceOptimizeRequest" }, "PostTrainingJob": { "type": "object", @@ -6880,7 +7418,8 @@ "additionalProperties": false, "required": [ "job_uuid" - ] + ], + "title": "PostTrainingJob" }, "DefaultRAGQueryGeneratorConfig": { "type": "object", @@ -6899,7 +7438,8 @@ "required": [ "type", "separator" - ] + ], + "title": "DefaultRAGQueryGeneratorConfig" }, "LLMRAGQueryGeneratorConfig": { "type": 
"object", @@ -6921,7 +7461,8 @@ "type", "model", "template" - ] + ], + "title": "LLMRAGQueryGeneratorConfig" }, "RAGQueryConfig": { "type": "object", @@ -6943,7 +7484,8 @@ "query_generator_config", "max_tokens_in_context", "max_chunks" - ] + ], + "title": "RAGQueryConfig" }, "RAGQueryGeneratorConfig": { "oneOf": [ @@ -6982,7 +7524,8 @@ "required": [ "content", "vector_db_ids" - ] + ], + "title": "QueryRequest" }, "RAGQueryResult": { "type": "object", @@ -6991,7 +7534,8 @@ "$ref": "#/components/schemas/InterleavedContent" } }, - "additionalProperties": false + "additionalProperties": false, + "title": "RAGQueryResult" }, "QueryChunksRequest": { "type": "object", @@ -7032,7 +7576,8 @@ "required": [ "vector_db_id", "query" - ] + ], + "title": "QueryChunksRequest" }, "QueryChunksResponse": { "type": "object", @@ -7075,7 +7620,8 @@ "required": [ "content", "metadata" - ] + ], + "title": "Chunk" } }, "scores": { @@ -7089,7 +7635,8 @@ "required": [ "chunks", "scores" - ] + ], + "title": "QueryChunksResponse" }, "QueryCondition": { "type": "object", @@ -7128,7 +7675,8 @@ "key", "op", "value" - ] + ], + "title": "QueryCondition" }, "QueryConditionOp": { "type": "string", @@ -7137,7 +7685,8 @@ "ne", "gt", "lt" - ] + ], + "title": "QueryConditionOp" }, "QuerySpansResponse": { "type": "object", @@ -7152,7 +7701,8 @@ "additionalProperties": false, "required": [ "data" - ] + ], + "title": "QuerySpansResponse" }, "QueryTracesResponse": { "type": "object", @@ -7167,7 +7717,63 @@ "additionalProperties": false, "required": [ "data" - ] + ], + "title": "QueryTracesResponse" + }, + "RegisterBenchmarkRequest": { + "type": "object", + "properties": { + "benchmark_id": { + "type": "string" + }, + "dataset_id": { + "type": "string" + }, + "scoring_functions": { + "type": "array", + "items": { + "type": "string" + } + }, + "provider_benchmark_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "benchmark_id", + "dataset_id", + "scoring_functions" + ], + "title": "RegisterBenchmarkRequest" }, "RegisterDatasetRequest": { "type": "object", @@ -7221,61 +7827,8 @@ "dataset_id", "dataset_schema", "url" - ] - }, - "RegisterEvalTaskRequest": { - "type": "object", - "properties": { - "eval_task_id": { - "type": "string" - }, - "dataset_id": { - "type": "string" - }, - "scoring_functions": { - "type": "array", - "items": { - "type": "string" - } - }, - "provider_eval_task_id": { - "type": "string" - }, - "provider_id": { - "type": "string" - }, - "metadata": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "eval_task_id", - "dataset_id", - "scoring_functions" - ] + ], + "title": "RegisterDatasetRequest" }, "RegisterModelRequest": { "type": "object", @@ -7321,7 +7874,8 @@ "additionalProperties": false, "required": [ "model_id" - ] + ], + "title": "RegisterModelRequest" }, "RegisterScoringFunctionRequest": { "type": "object", @@ -7350,7 +7904,8 @@ "scoring_fn_id", "description", "return_type" - ] + ], + "title": "RegisterScoringFunctionRequest" }, 
"RegisterShieldRequest": { "type": "object", @@ -7393,7 +7948,8 @@ "additionalProperties": false, "required": [ "shield_id" - ] + ], + "title": "RegisterShieldRequest" }, "RegisterToolGroupRequest": { "type": "object", @@ -7437,7 +7993,8 @@ "required": [ "toolgroup_id", "provider_id" - ] + ], + "title": "RegisterToolGroupRequest" }, "RegisterVectorDbRequest": { "type": "object", @@ -7462,31 +8019,21 @@ "required": [ "vector_db_id", "embedding_model" - ] + ], + "title": "RegisterVectorDbRequest" }, "RunEvalRequest": { "type": "object", "properties": { "task_config": { - "$ref": "#/components/schemas/EvalTaskConfig" + "$ref": "#/components/schemas/BenchmarkConfig" } }, "additionalProperties": false, "required": [ "task_config" - ] - }, - "Job": { - "type": "object", - "properties": { - "job_id": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "job_id" - ] + ], + "title": "RunEvalRequest" }, "RunShieldRequest": { "type": "object", @@ -7531,7 +8078,8 @@ "shield_id", "messages", "params" - ] + ], + "title": "RunShieldRequest" }, "RunShieldResponse": { "type": "object", @@ -7540,7 +8088,8 @@ "$ref": "#/components/schemas/SafetyViolation" } }, - "additionalProperties": false + "additionalProperties": false, + "title": "RunShieldResponse" }, "SaveSpansToDatasetRequest": { "type": "object", @@ -7569,7 +8118,8 @@ "attribute_filters", "attributes_to_save", "dataset_id" - ] + ], + "title": "SaveSpansToDatasetRequest" }, "ScoreRequest": { "type": "object", @@ -7620,7 +8170,8 @@ "required": [ "input_rows", "scoring_functions" - ] + ], + "title": "ScoreRequest" }, "ScoreResponse": { "type": "object", @@ -7635,7 +8186,8 @@ "additionalProperties": false, "required": [ "results" - ] + ], + "title": "ScoreResponse" }, "ScoreBatchRequest": { "type": "object", @@ -7665,7 +8217,8 @@ "dataset_id", "scoring_functions", "save_results_dataset" - ] + ], + "title": "ScoreBatchRequest" }, "ScoreBatchResponse": { "type": "object", @@ -7683,7 +8236,8 @@ "additionalProperties": false, "required": [ "results" - ] + ], + "title": "ScoreBatchResponse" }, "AlgorithmConfig": { "oneOf": [ @@ -7745,7 +8299,8 @@ "apply_lora_to_output", "rank", "alpha" - ] + ], + "title": "LoraFinetuningConfig" }, "QATFinetuningConfig": { "type": "object", @@ -7767,7 +8322,8 @@ "type", "quantizer_name", "group_size" - ] + ], + "title": "QATFinetuningConfig" }, "SupervisedFineTuneRequest": { "type": "object", @@ -7845,7 +8401,8 @@ "hyperparam_search_config", "logger_config", "model" - ] + ], + "title": "SupervisedFineTuneRequest" }, "SyntheticDataGenerateRequest": { "type": "object", @@ -7866,6 +8423,7 @@ "top_k_top_p", "sigmoid" ], + "title": "FilteringFunction", "description": "The type of filtering function." }, "model": { @@ -7876,7 +8434,8 @@ "required": [ "dialogs", "filtering_function" - ] + ], + "title": "SyntheticDataGenerateRequest" }, "SyntheticDataGenerationResponse": { "type": "object", @@ -7939,6 +8498,7 @@ "required": [ "synthetic_data" ], + "title": "SyntheticDataGenerationResponse", "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." 
}, "VersionInfo": { @@ -7951,7 +8511,8 @@ "additionalProperties": false, "required": [ "version" - ] + ], + "title": "VersionInfo" } }, "responses": {} @@ -7970,6 +8531,9 @@ { "name": "BatchInference (Coming Soon)" }, + { + "name": "Benchmarks" + }, { "name": "DatasetIO" }, @@ -7979,9 +8543,6 @@ { "name": "Eval" }, - { - "name": "EvalTasks" - }, { "name": "Inference", "description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", @@ -8033,10 +8594,10 @@ "tags": [ "Agents", "BatchInference (Coming Soon)", + "Benchmarks", "DatasetIO", "Datasets", "Eval", - "EvalTasks", "Inference", "Inspect", "Models", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index a646d7e08..99300fedf 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -10,6 +10,175 @@ info: servers: - url: http://any-hosted-llama-stack.com paths: + /v1/eval/tasks/{task_id}/evaluations: + post: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluateResponse' + tags: + - Eval + description: '' + parameters: + - name: task_id + in: path + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/DeprecatedEvaluateRowsRequest' + required: true + deprecated: true + /v1/eval-tasks/{eval_task_id}: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/Benchmark' + - type: 'null' + tags: + - Benchmarks + description: '' + parameters: + - name: eval_task_id + in: path + required: true + schema: + type: string + deprecated: true + /v1/eval/tasks/{task_id}/jobs/{job_id}: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/JobStatus' + - type: 'null' + tags: + - Eval + description: '' + parameters: + - name: task_id + in: path + required: true + schema: + type: string + - name: job_id + in: path + required: true + schema: + type: string + deprecated: true + delete: + responses: + '200': + description: OK + tags: + - Eval + description: '' + parameters: + - name: task_id + in: path + required: true + schema: + type: string + - name: job_id + in: path + required: true + schema: + type: string + deprecated: true + /v1/eval/tasks/{task_id}/jobs/{job_id}/result: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/EvaluateResponse' + tags: + - Eval + description: '' + parameters: + - name: task_id + in: path + required: true + schema: + type: string + - name: job_id + in: path + required: true + schema: + type: string + deprecated: true + /v1/eval-tasks: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/ListBenchmarksResponse' + tags: + - Benchmarks + description: '' + parameters: [] + deprecated: true + post: + responses: + '200': + description: OK + tags: + - Benchmarks + description: '' + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/DeprecatedRegisterEvalTaskRequest' + required: true + deprecated: true + /v1/eval/tasks/{task_id}/jobs: + post: + responses: + '200': + description: OK + 
content: + application/json: + schema: + $ref: '#/components/schemas/Job' + tags: + - Eval + description: '' + parameters: + - name: task_id + in: path + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/DeprecatedRunEvalRequest' + required: true + deprecated: true /v1/datasetio/rows: get: responses: @@ -322,7 +491,7 @@ paths: schema: $ref: '#/components/schemas/EmbeddingsRequest' required: true - /v1/eval/tasks/{task_id}/evaluations: + /v1/eval/benchmarks/{benchmark_id}/evaluations: post: responses: '200': @@ -335,7 +504,7 @@ paths: - Eval description: '' parameters: - - name: task_id + - name: benchmark_id in: path required: true schema: @@ -407,6 +576,26 @@ paths: required: true schema: type: string + /v1/eval/benchmarks/{benchmark_id}: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/Benchmark' + - type: 'null' + tags: + - Benchmarks + description: '' + parameters: + - name: benchmark_id + in: path + required: true + schema: + type: string /v1/datasets/{dataset_id}: get: responses: @@ -440,26 +629,6 @@ paths: required: true schema: type: string - /v1/eval-tasks/{eval_task_id}: - get: - responses: - '200': - description: OK - content: - application/json: - schema: - oneOf: - - $ref: '#/components/schemas/EvalTask' - - type: 'null' - tags: - - EvalTasks - description: '' - parameters: - - name: eval_task_id - in: path - required: true - schema: - type: string /v1/models/{model_id}: get: responses: @@ -802,7 +971,7 @@ paths: schema: $ref: '#/components/schemas/InvokeToolRequest' required: true - /v1/eval/tasks/{task_id}/jobs/{job_id}: + /v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}: get: responses: '200': @@ -817,7 +986,7 @@ paths: - Eval description: '' parameters: - - name: task_id + - name: benchmark_id in: path required: true schema: @@ -835,7 +1004,7 @@ paths: - Eval description: '' parameters: - - name: task_id + - name: benchmark_id in: path required: true schema: @@ -845,7 +1014,7 @@ paths: required: true schema: type: string - /v1/eval/tasks/{task_id}/jobs/{job_id}/result: + /v1/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result: get: responses: '200': @@ -858,16 +1027,43 @@ paths: - Eval description: '' parameters: + - name: benchmark_id + in: path + required: true + schema: + type: string - name: job_id in: path required: true schema: type: string - - name: task_id - in: path - required: true - schema: - type: string + /v1/eval/benchmarks: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/ListBenchmarksResponse' + tags: + - Benchmarks + description: '' + parameters: [] + post: + responses: + '200': + description: OK + tags: + - Benchmarks + description: '' + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterBenchmarkRequest' + required: true /v1/datasets: get: responses: @@ -895,33 +1091,6 @@ paths: schema: $ref: '#/components/schemas/RegisterDatasetRequest' required: true - /v1/eval-tasks: - get: - responses: - '200': - description: OK - content: - application/json: - schema: - $ref: '#/components/schemas/ListEvalTasksResponse' - tags: - - EvalTasks - description: '' - parameters: [] - post: - responses: - '200': - description: OK - tags: - - EvalTasks - description: '' - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/RegisterEvalTaskRequest' 
- required: true /v1/models: get: responses: @@ -1278,7 +1447,7 @@ paths: type: array items: type: string - /v1/eval/tasks/{task_id}/jobs: + /v1/eval/benchmarks/{benchmark_id}/jobs: post: responses: '200': @@ -1291,7 +1460,7 @@ paths: - Eval description: '' parameters: - - name: task_id + - name: benchmark_id in: path required: true schema: @@ -1429,65 +1598,157 @@ jsonSchemaDialect: >- https://json-schema.org/draft/2020-12/schema components: schemas: - AppendRowsRequest: + AgentCandidate: type: object properties: - dataset_id: + type: type: string - rows: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + const: agent + default: agent + config: + $ref: '#/components/schemas/AgentConfig' additionalProperties: false required: - - dataset_id - - rows - CompletionMessage: + - type + - config + title: AgentCandidate + AgentConfig: type: object properties: - role: - type: string - const: assistant - default: assistant - description: >- - Must be "assistant" to identify this as the model's response - content: - $ref: '#/components/schemas/InterleavedContent' - description: The content of the model's response - stop_reason: + sampling_params: + $ref: '#/components/schemas/SamplingParams' + input_shields: + type: array + items: + type: string + output_shields: + type: array + items: + type: string + toolgroups: + type: array + items: + $ref: '#/components/schemas/AgentTool' + client_tools: + type: array + items: + $ref: '#/components/schemas/ToolDef' + tool_choice: type: string enum: - - end_of_turn - - end_of_message - - out_of_tokens + - auto + - required + - none + title: ToolChoice description: >- - Reason why the model stopped generating. Options are: - `StopReason.end_of_turn`: - The model finished generating the entire response. - `StopReason.end_of_message`: - The model finished generating but generated a partial response -- usually, - a tool call. The user may call the tool and continue the conversation - with the tool's response. - `StopReason.out_of_tokens`: The model ran - out of token budget. - tool_calls: - type: array - items: - $ref: '#/components/schemas/ToolCall' + Whether tool use is required or automatic. This is a hint to the model + which may not be followed. It depends on the Instruction Following capabilities + of the model. + deprecated: true + tool_prompt_format: + type: string + enum: + - json + - function_tag + - python_list + title: ToolPromptFormat description: >- - List of tool calls. Each tool call is a ToolCall object. + Prompt format for calling custom / zero shot tools. + deprecated: true + tool_config: + $ref: '#/components/schemas/ToolConfig' + max_infer_iters: + type: integer + default: 10 + model: + type: string + instructions: + type: string + enable_session_persistence: + type: boolean + default: false + response_format: + $ref: '#/components/schemas/ResponseFormat' additionalProperties: false required: - - role - - content - - stop_reason - description: >- - A message containing the model's (assistant) response in a chat conversation. 
+ - model + - instructions + title: AgentConfig + AgentTool: + oneOf: + - type: string + - type: object + properties: + name: + type: string + args: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - name + - args + title: AgentToolGroupWithArgs + AggregationFunctionType: + type: string + enum: + - average + - median + - categorical_count + - accuracy + title: AggregationFunctionType + BasicScoringFnParams: + type: object + properties: + type: + type: string + const: basic + default: basic + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + additionalProperties: false + required: + - type + title: BasicScoringFnParams + BenchmarkConfig: + type: object + properties: + type: + type: string + const: benchmark + default: benchmark + eval_candidate: + $ref: '#/components/schemas/EvalCandidate' + scoring_params: + type: object + additionalProperties: + $ref: '#/components/schemas/ScoringFnParams' + num_examples: + type: integer + additionalProperties: false + required: + - type + - eval_candidate + - scoring_params + title: BenchmarkConfig + EvalCandidate: + oneOf: + - $ref: '#/components/schemas/ModelCandidate' + - $ref: '#/components/schemas/AgentCandidate' + discriminator: + propertyName: type + mapping: + model: '#/components/schemas/ModelCandidate' + agent: '#/components/schemas/AgentCandidate' GrammarResponseFormat: type: object properties: @@ -1513,6 +1774,7 @@ components: required: - type - bnf + title: GrammarResponseFormat description: >- Configuration for grammar-guided response generation. GreedySamplingStrategy: @@ -1525,6 +1787,7 @@ components: additionalProperties: false required: - type + title: GreedySamplingStrategy ImageContentItem: type: object properties: @@ -1553,6 +1816,7 @@ components: required: - type - image + title: ImageContentItem description: A image content item InterleavedContent: oneOf: @@ -1596,21 +1860,71 @@ components: required: - type - json_schema + title: JsonSchemaResponseFormat description: >- Configuration for JSON schema-guided response generation. 
- Message: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: '#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' - discriminator: - propertyName: role - mapping: - user: '#/components/schemas/UserMessage' - system: '#/components/schemas/SystemMessage' - tool: '#/components/schemas/ToolResponseMessage' - assistant: '#/components/schemas/CompletionMessage' + LLMAsJudgeScoringFnParams: + type: object + properties: + type: + type: string + const: llm_as_judge + default: llm_as_judge + judge_model: + type: string + prompt_template: + type: string + judge_score_regexes: + type: array + items: + type: string + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + additionalProperties: false + required: + - type + - judge_model + title: LLMAsJudgeScoringFnParams + ModelCandidate: + type: object + properties: + type: + type: string + const: model + default: model + model: + type: string + sampling_params: + $ref: '#/components/schemas/SamplingParams' + system_message: + $ref: '#/components/schemas/SystemMessage' + additionalProperties: false + required: + - type + - model + - sampling_params + title: ModelCandidate + RegexParserScoringFnParams: + type: object + properties: + type: + type: string + const: regex_parser + default: regex_parser + parsing_regexes: + type: array + items: + type: string + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + additionalProperties: false + required: + - type + title: RegexParserScoringFnParams ResponseFormat: oneOf: - $ref: '#/components/schemas/JsonSchemaResponseFormat' @@ -1634,6 +1948,7 @@ components: additionalProperties: false required: - strategy + title: SamplingParams SamplingStrategy: oneOf: - $ref: '#/components/schemas/GreedySamplingStrategy' @@ -1645,6 +1960,17 @@ components: greedy: '#/components/schemas/GreedySamplingStrategy' top_p: '#/components/schemas/TopPSamplingStrategy' top_k: '#/components/schemas/TopKSamplingStrategy' + ScoringFnParams: + oneOf: + - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' + - $ref: '#/components/schemas/RegexParserScoringFnParams' + - $ref: '#/components/schemas/BasicScoringFnParams' + discriminator: + propertyName: type + mapping: + llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' + regex_parser: '#/components/schemas/RegexParserScoringFnParams' + basic: '#/components/schemas/BasicScoringFnParams' SystemMessage: type: object properties: @@ -1664,6 +1990,7 @@ components: required: - role - content + title: SystemMessage description: >- A system message providing instructions or context to the model. TextContentItem: @@ -1682,7 +2009,409 @@ components: required: - type - text + title: TextContentItem description: A text content item + ToolConfig: + type: object + properties: + tool_choice: + oneOf: + - type: string + enum: + - auto + - required + - none + title: ToolChoice + description: >- + Whether tool use is required or automatic. This is a hint to the model + which may not be followed. It depends on the Instruction Following + capabilities of the model. + - type: string + default: auto + description: >- + (Optional) Whether tool use is automatic, required, or none. Can also + specify a tool name to use a specific tool. Defaults to ToolChoice.auto. 
+ tool_prompt_format: + type: string + enum: + - json + - function_tag + - python_list + description: >- + (Optional) Instructs the model how to format tool calls. By default, Llama + Stack will attempt to use a format that is best adapted to the model. + - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. + - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a + tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python + syntax -- a list of function calls. + system_message_behavior: + type: string + enum: + - append + - replace + description: >- + (Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: + Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: + Replaces the default system prompt with the provided system message. The + system message can include the string '{{function_definitions}}' to indicate + where the function definitions should be inserted. + default: append + additionalProperties: false + title: ToolConfig + description: Configuration for tool use. + ToolDef: + type: object + properties: + name: + type: string + description: + type: string + parameters: + type: array + items: + $ref: '#/components/schemas/ToolParameter' + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - name + title: ToolDef + ToolParameter: + type: object + properties: + name: + type: string + parameter_type: + type: string + description: + type: string + required: + type: boolean + default: true + default: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - name + - parameter_type + - description + - required + title: ToolParameter + TopKSamplingStrategy: + type: object + properties: + type: + type: string + const: top_k + default: top_k + top_k: + type: integer + additionalProperties: false + required: + - type + - top_k + title: TopKSamplingStrategy + TopPSamplingStrategy: + type: object + properties: + type: + type: string + const: top_p + default: top_p + temperature: + type: number + top_p: + type: number + default: 0.95 + additionalProperties: false + required: + - type + title: TopPSamplingStrategy + URL: + type: object + properties: + uri: + type: string + additionalProperties: false + required: + - uri + title: URL + DeprecatedEvaluateRowsRequest: + type: object + properties: + input_rows: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + scoring_functions: + type: array + items: + type: string + task_config: + $ref: '#/components/schemas/BenchmarkConfig' + additionalProperties: false + required: + - input_rows + - scoring_functions + - task_config + title: DeprecatedEvaluateRowsRequest + EvaluateResponse: + type: object + properties: + generations: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + scores: + type: object + additionalProperties: + $ref: '#/components/schemas/ScoringResult' + additionalProperties: false + required: + - generations + - scores + title: EvaluateResponse + ScoringResult: + type: object + properties: + score_rows: + type: 
array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + aggregated_results: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - score_rows + - aggregated_results + title: ScoringResult + Benchmark: + type: object + properties: + identifier: + type: string + provider_resource_id: + type: string + provider_id: + type: string + type: + type: string + const: benchmark + default: benchmark + dataset_id: + type: string + scoring_functions: + type: array + items: + type: string + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - identifier + - provider_resource_id + - provider_id + - type + - dataset_id + - scoring_functions + - metadata + title: Benchmark + JobStatus: + type: string + enum: + - completed + - in_progress + - failed + - scheduled + title: JobStatus + ListBenchmarksResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/Benchmark' + additionalProperties: false + required: + - data + title: ListBenchmarksResponse + DeprecatedRegisterEvalTaskRequest: + type: object + properties: + eval_task_id: + type: string + dataset_id: + type: string + scoring_functions: + type: array + items: + type: string + provider_benchmark_id: + type: string + provider_id: + type: string + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - eval_task_id + - dataset_id + - scoring_functions + title: DeprecatedRegisterEvalTaskRequest + DeprecatedRunEvalRequest: + type: object + properties: + task_config: + $ref: '#/components/schemas/BenchmarkConfig' + additionalProperties: false + required: + - task_config + title: DeprecatedRunEvalRequest + Job: + type: object + properties: + job_id: + type: string + additionalProperties: false + required: + - job_id + title: Job + AppendRowsRequest: + type: object + properties: + dataset_id: + type: string + rows: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - dataset_id + - rows + title: AppendRowsRequest + CompletionMessage: + type: object + properties: + role: + type: string + const: assistant + default: assistant + description: >- + Must be "assistant" to identify this as the model's response + content: + $ref: '#/components/schemas/InterleavedContent' + description: The content of the model's response + stop_reason: + type: string + enum: + - end_of_turn + - end_of_message + - out_of_tokens + description: >- + Reason why the model stopped generating. Options are: - `StopReason.end_of_turn`: + The model finished generating the entire response. - `StopReason.end_of_message`: + The model finished generating but generated a partial response -- usually, + a tool call. The user may call the tool and continue the conversation + with the tool's response. - `StopReason.out_of_tokens`: The model ran + out of token budget. 
+ tool_calls: + type: array + items: + $ref: '#/components/schemas/ToolCall' + description: >- + List of tool calls. Each tool call is a ToolCall object. + additionalProperties: false + required: + - role + - content + - stop_reason + title: CompletionMessage + description: >- + A message containing the model's (assistant) response in a chat conversation. + Message: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' + discriminator: + propertyName: role + mapping: + user: '#/components/schemas/UserMessage' + system: '#/components/schemas/SystemMessage' + tool: '#/components/schemas/ToolResponseMessage' + assistant: '#/components/schemas/CompletionMessage' ToolCall: type: object properties: @@ -1696,6 +2425,7 @@ components: - wolfram_alpha - photogen - code_interpreter + title: BuiltinTool - type: string arguments: type: object @@ -1727,6 +2457,7 @@ components: - call_id - tool_name - arguments + title: ToolCall ToolDefinition: type: object properties: @@ -1738,6 +2469,7 @@ components: - wolfram_alpha - photogen - code_interpreter + title: BuiltinTool - type: string description: type: string @@ -1748,6 +2480,7 @@ components: additionalProperties: false required: - tool_name + title: ToolDefinition ToolParamDefinition: type: object properties: @@ -1769,6 +2502,7 @@ components: additionalProperties: false required: - param_type + title: ToolParamDefinition ToolResponseMessage: type: object properties: @@ -1790,6 +2524,7 @@ components: - wolfram_alpha - photogen - code_interpreter + title: BuiltinTool - type: string description: Name of the tool that was called content: @@ -1801,44 +2536,9 @@ components: - call_id - tool_name - content + title: ToolResponseMessage description: >- A message representing the result of a tool invocation. - TopKSamplingStrategy: - type: object - properties: - type: - type: string - const: top_k - default: top_k - top_k: - type: integer - additionalProperties: false - required: - - type - - top_k - TopPSamplingStrategy: - type: object - properties: - type: - type: string - const: top_p - default: top_p - temperature: - type: number - top_p: - type: number - default: 0.95 - additionalProperties: false - required: - - type - URL: - type: object - properties: - uri: - type: string - additionalProperties: false - required: - - uri UserMessage: type: object properties: @@ -1861,6 +2561,7 @@ components: required: - role - content + title: UserMessage description: >- A message from the user in a chat conversation. BatchChatCompletionRequest: @@ -1885,6 +2586,8 @@ components: enum: - auto - required + - none + title: ToolChoice description: >- Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities @@ -1895,6 +2598,7 @@ components: - json - function_tag - python_list + title: ToolPromptFormat description: >- Prompt format for calling custom / zero shot tools. response_format: @@ -1908,10 +2612,12 @@ components: description: >- How many tokens (for each position) to return log probabilities for. 
additionalProperties: false + title: LogProbConfig additionalProperties: false required: - model - messages_batch + title: BatchChatCompletionRequest BatchChatCompletionResponse: type: object properties: @@ -1922,6 +2628,7 @@ components: additionalProperties: false required: - batch + title: BatchChatCompletionResponse ChatCompletionResponse: type: object properties: @@ -1941,6 +2648,7 @@ components: additionalProperties: false required: - completion_message + title: ChatCompletionResponse description: Response from a chat completion request. MetricEvent: type: object @@ -1982,6 +2690,7 @@ components: - metric - value - unit + title: MetricEvent TokenLogProbs: type: object properties: @@ -1994,6 +2703,7 @@ components: additionalProperties: false required: - logprobs_by_token + title: TokenLogProbs description: Log probabilities for generated tokens. BatchCompletionRequest: type: object @@ -2017,10 +2727,12 @@ components: description: >- How many tokens (for each position) to return log probabilities for. additionalProperties: false + title: LogProbConfig additionalProperties: false required: - model - content_batch + title: BatchCompletionRequest BatchCompletionResponse: type: object properties: @@ -2031,6 +2743,7 @@ components: additionalProperties: false required: - batch + title: BatchCompletionResponse CompletionResponse: type: object properties: @@ -2054,6 +2767,7 @@ components: required: - content - stop_reason + title: CompletionResponse description: Response from a completion request. CancelTrainingJobRequest: type: object @@ -2063,46 +2777,7 @@ components: additionalProperties: false required: - job_uuid - ToolConfig: - type: object - properties: - tool_choice: - type: string - enum: - - auto - - required - description: >- - (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. - default: auto - tool_prompt_format: - type: string - enum: - - json - - function_tag - - python_list - description: >- - (Optional) Instructs the model how to format tool calls. By default, Llama - Stack will attempt to use a format that is best adapted to the model. - - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a - tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python - syntax -- a list of function calls. - system_message_behavior: - type: string - enum: - - append - - replace - description: >- - (Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: - Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: - Replaces the default system prompt with the provided system message. The - system message can include the string '{{function_definitions}}' to indicate - where the function definitions should be inserted. - default: append - additionalProperties: false - required: - - system_message_behavior - description: Configuration for tool use. + title: CancelTrainingJobRequest ChatCompletionRequest: type: object properties: @@ -2131,6 +2806,7 @@ components: enum: - auto - required + - none description: >- (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. .. deprecated:: Use tool_config instead. 
@@ -2179,6 +2855,7 @@ components: required: - model_id - messages + title: ChatCompletionRequest ChatCompletionResponseEvent: type: object properties: @@ -2212,6 +2889,7 @@ components: required: - event_type - delta + title: ChatCompletionResponseEvent description: >- An event during chat completion generation. ChatCompletionResponseStreamChunk: @@ -2227,6 +2905,7 @@ components: additionalProperties: false required: - event + title: ChatCompletionResponseStreamChunk description: >- A chunk of a streamed chat completion response. ContentDelta: @@ -2254,6 +2933,7 @@ components: required: - type - image + title: ImageDelta TextDelta: type: object properties: @@ -2267,6 +2947,7 @@ components: required: - type - text + title: TextDelta ToolCallDelta: type: object properties: @@ -2285,11 +2966,13 @@ components: - in_progress - failed - succeeded + title: ToolCallParseStatus additionalProperties: false required: - type - tool_call - parse_status + title: ToolCallDelta CompletionRequest: type: object properties: @@ -2330,6 +3013,7 @@ components: required: - model_id - content + title: CompletionRequest CompletionResponseStreamChunk: type: object properties: @@ -2354,135 +3038,9 @@ components: additionalProperties: false required: - delta + title: CompletionResponseStreamChunk description: >- A chunk of a streamed completion response. - AgentConfig: - type: object - properties: - sampling_params: - $ref: '#/components/schemas/SamplingParams' - input_shields: - type: array - items: - type: string - output_shields: - type: array - items: - type: string - toolgroups: - type: array - items: - $ref: '#/components/schemas/AgentTool' - client_tools: - type: array - items: - $ref: '#/components/schemas/ToolDef' - tool_choice: - type: string - enum: - - auto - - required - description: >- - Whether tool use is required or automatic. This is a hint to the model - which may not be followed. It depends on the Instruction Following capabilities - of the model. - tool_prompt_format: - type: string - enum: - - json - - function_tag - - python_list - description: >- - Prompt format for calling custom / zero shot tools. 
- tool_config: - $ref: '#/components/schemas/ToolConfig' - max_infer_iters: - type: integer - default: 10 - model: - type: string - instructions: - type: string - enable_session_persistence: - type: boolean - response_format: - $ref: '#/components/schemas/ResponseFormat' - additionalProperties: false - required: - - model - - instructions - - enable_session_persistence - AgentTool: - oneOf: - - type: string - - type: object - properties: - name: - type: string - args: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - additionalProperties: false - required: - - name - - args - ToolDef: - type: object - properties: - name: - type: string - description: - type: string - parameters: - type: array - items: - $ref: '#/components/schemas/ToolParameter' - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - additionalProperties: false - required: - - name - ToolParameter: - type: object - properties: - name: - type: string - parameter_type: - type: string - description: - type: string - required: - type: boolean - default: true - default: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - additionalProperties: false - required: - - name - - parameter_type - - description - - required CreateAgentRequest: type: object properties: @@ -2491,6 +3049,7 @@ components: additionalProperties: false required: - agent_config + title: CreateAgentRequest AgentCreateResponse: type: object properties: @@ -2499,6 +3058,7 @@ components: additionalProperties: false required: - agent_id + title: AgentCreateResponse CreateAgentSessionRequest: type: object properties: @@ -2507,6 +3067,7 @@ components: additionalProperties: false required: - session_name + title: CreateAgentSessionRequest AgentSessionCreateResponse: type: object properties: @@ -2515,6 +3076,7 @@ components: additionalProperties: false required: - session_id + title: AgentSessionCreateResponse CreateAgentTurnRequest: type: object properties: @@ -2545,6 +3107,7 @@ components: required: - content - mime_type + title: Document toolgroups: type: array items: @@ -2554,6 +3117,7 @@ components: additionalProperties: false required: - messages + title: CreateAgentTurnRequest InferenceStep: type: object properties: @@ -2579,6 +3143,7 @@ components: - step_id - step_type - model_response + title: InferenceStep MemoryRetrievalStep: type: object properties: @@ -2607,6 +3172,7 @@ components: - step_type - vector_db_ids - inserted_context + title: MemoryRetrievalStep SafetyViolation: type: object properties: @@ -2628,6 +3194,7 @@ components: required: - violation_level - metadata + title: SafetyViolation ShieldCallStep: type: object properties: @@ -2652,6 +3219,7 @@ components: - turn_id - step_id - step_type + title: ShieldCallStep ToolExecutionStep: type: object properties: @@ -2684,6 +3252,7 @@ components: - step_type - tool_calls - tool_responses + title: ToolExecutionStep ToolResponse: type: object properties: @@ -2697,6 +3266,7 @@ components: - wolfram_alpha - photogen - code_interpreter + title: BuiltinTool - type: string content: $ref: '#/components/schemas/InterleavedContent' @@ -2705,6 +3275,7 @@ components: - call_id - tool_name - content + title: ToolResponse Turn: type: object properties: @@ -2754,6 +3325,7 @@ components: required: - content - mime_type + title: Attachment started_at: 
type: string format: date-time @@ -2768,6 +3340,7 @@ components: - steps - output_message - started_at + title: Turn description: >- A single turn in an interaction with an Agentic System. ViolationLevel: @@ -2776,6 +3349,7 @@ components: - info - warn - error + title: ViolationLevel AgentTurnResponseEvent: type: object properties: @@ -2784,6 +3358,7 @@ components: additionalProperties: false required: - payload + title: AgentTurnResponseEvent AgentTurnResponseEventPayload: oneOf: - $ref: '#/components/schemas/AgentTurnResponseStepStartPayload' @@ -2813,6 +3388,7 @@ components: - tool_execution - shield_call - memory_retrieval + title: StepType step_id: type: string step_details: @@ -2834,6 +3410,7 @@ components: - step_type - step_id - step_details + title: AgentTurnResponseStepCompletePayload AgentTurnResponseStepProgressPayload: type: object properties: @@ -2848,6 +3425,7 @@ components: - tool_execution - shield_call - memory_retrieval + title: StepType step_id: type: string delta: @@ -2858,6 +3436,7 @@ components: - step_type - step_id - delta + title: AgentTurnResponseStepProgressPayload AgentTurnResponseStepStartPayload: type: object properties: @@ -2872,6 +3451,7 @@ components: - tool_execution - shield_call - memory_retrieval + title: StepType step_id: type: string metadata: @@ -2889,6 +3469,7 @@ components: - event_type - step_type - step_id + title: AgentTurnResponseStepStartPayload AgentTurnResponseStreamChunk: type: object properties: @@ -2897,6 +3478,7 @@ components: additionalProperties: false required: - event + title: AgentTurnResponseStreamChunk description: streamed agent turn completion response. AgentTurnResponseTurnCompletePayload: type: object @@ -2911,6 +3493,7 @@ components: required: - event_type - turn + title: AgentTurnResponseTurnCompletePayload AgentTurnResponseTurnStartPayload: type: object properties: @@ -2924,6 +3507,7 @@ components: required: - event_type - turn_id + title: AgentTurnResponseTurnStartPayload EmbeddingsRequest: type: object properties: @@ -2944,6 +3528,7 @@ components: required: - model_id - contents + title: EmbeddingsRequest EmbeddingsResponse: type: object properties: @@ -2960,165 +3545,9 @@ components: additionalProperties: false required: - embeddings + title: EmbeddingsResponse description: >- Response containing generated embeddings. 
- AgentCandidate: - type: object - properties: - type: - type: string - const: agent - default: agent - config: - $ref: '#/components/schemas/AgentConfig' - additionalProperties: false - required: - - type - - config - AggregationFunctionType: - type: string - enum: - - average - - median - - categorical_count - - accuracy - AppEvalTaskConfig: - type: object - properties: - type: - type: string - const: app - default: app - eval_candidate: - $ref: '#/components/schemas/EvalCandidate' - scoring_params: - type: object - additionalProperties: - $ref: '#/components/schemas/ScoringFnParams' - num_examples: - type: integer - additionalProperties: false - required: - - type - - eval_candidate - - scoring_params - BasicScoringFnParams: - type: object - properties: - type: - type: string - const: basic - default: basic - aggregation_functions: - type: array - items: - $ref: '#/components/schemas/AggregationFunctionType' - additionalProperties: false - required: - - type - BenchmarkEvalTaskConfig: - type: object - properties: - type: - type: string - const: benchmark - default: benchmark - eval_candidate: - $ref: '#/components/schemas/EvalCandidate' - num_examples: - type: integer - additionalProperties: false - required: - - type - - eval_candidate - EvalCandidate: - oneOf: - - $ref: '#/components/schemas/ModelCandidate' - - $ref: '#/components/schemas/AgentCandidate' - discriminator: - propertyName: type - mapping: - model: '#/components/schemas/ModelCandidate' - agent: '#/components/schemas/AgentCandidate' - EvalTaskConfig: - oneOf: - - $ref: '#/components/schemas/BenchmarkEvalTaskConfig' - - $ref: '#/components/schemas/AppEvalTaskConfig' - discriminator: - propertyName: type - mapping: - benchmark: '#/components/schemas/BenchmarkEvalTaskConfig' - app: '#/components/schemas/AppEvalTaskConfig' - LLMAsJudgeScoringFnParams: - type: object - properties: - type: - type: string - const: llm_as_judge - default: llm_as_judge - judge_model: - type: string - prompt_template: - type: string - judge_score_regexes: - type: array - items: - type: string - aggregation_functions: - type: array - items: - $ref: '#/components/schemas/AggregationFunctionType' - additionalProperties: false - required: - - type - - judge_model - ModelCandidate: - type: object - properties: - type: - type: string - const: model - default: model - model: - type: string - sampling_params: - $ref: '#/components/schemas/SamplingParams' - system_message: - $ref: '#/components/schemas/SystemMessage' - additionalProperties: false - required: - - type - - model - - sampling_params - RegexParserScoringFnParams: - type: object - properties: - type: - type: string - const: regex_parser - default: regex_parser - parsing_regexes: - type: array - items: - type: string - aggregation_functions: - type: array - items: - $ref: '#/components/schemas/AggregationFunctionType' - additionalProperties: false - required: - - type - ScoringFnParams: - oneOf: - - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - - $ref: '#/components/schemas/RegexParserScoringFnParams' - - $ref: '#/components/schemas/BasicScoringFnParams' - discriminator: - propertyName: type - mapping: - llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' - regex_parser: '#/components/schemas/RegexParserScoringFnParams' - basic: '#/components/schemas/BasicScoringFnParams' EvaluateRowsRequest: type: object properties: @@ -3139,64 +3568,13 @@ components: items: type: string task_config: - $ref: '#/components/schemas/EvalTaskConfig' + $ref: '#/components/schemas/BenchmarkConfig' 
additionalProperties: false required: - input_rows - scoring_functions - task_config - EvaluateResponse: - type: object - properties: - generations: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - scores: - type: object - additionalProperties: - $ref: '#/components/schemas/ScoringResult' - additionalProperties: false - required: - - generations - - scores - ScoringResult: - type: object - properties: - score_rows: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - aggregated_results: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - additionalProperties: false - required: - - score_rows - - aggregated_results + title: EvaluateRowsRequest Session: type: object properties: @@ -3217,6 +3595,7 @@ components: - session_name - turns - started_at + title: Session description: >- A single session of an interaction with an Agentic System. AgentStepResponse: @@ -3238,6 +3617,7 @@ components: additionalProperties: false required: - step + title: AgentStepResponse AgentTurnInputType: type: object properties: @@ -3248,6 +3628,7 @@ components: additionalProperties: false required: - type + title: AgentTurnInputType ArrayType: type: object properties: @@ -3258,6 +3639,7 @@ components: additionalProperties: false required: - type + title: ArrayType BooleanType: type: object properties: @@ -3268,6 +3650,7 @@ components: additionalProperties: false required: - type + title: BooleanType ChatCompletionInputType: type: object properties: @@ -3278,6 +3661,7 @@ components: additionalProperties: false required: - type + title: ChatCompletionInputType CompletionInputType: type: object properties: @@ -3288,6 +3672,7 @@ components: additionalProperties: false required: - type + title: CompletionInputType Dataset: type: object properties: @@ -3326,6 +3711,7 @@ components: - dataset_schema - url - metadata + title: Dataset JsonType: type: object properties: @@ -3336,6 +3722,7 @@ components: additionalProperties: false required: - type + title: JsonType NumberType: type: object properties: @@ -3346,6 +3733,7 @@ components: additionalProperties: false required: - type + title: NumberType ObjectType: type: object properties: @@ -3356,6 +3744,7 @@ components: additionalProperties: false required: - type + title: ObjectType ParamType: oneOf: - $ref: '#/components/schemas/StringType' @@ -3391,6 +3780,7 @@ components: additionalProperties: false required: - type + title: StringType UnionType: type: object properties: @@ -3401,44 +3791,7 @@ components: additionalProperties: false required: - type - EvalTask: - type: object - properties: - identifier: - type: string - provider_resource_id: - type: string - provider_id: - type: string - type: - type: string - const: eval_task - default: eval_task - dataset_id: - type: string - scoring_functions: - type: array - items: - type: string - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - additionalProperties: false - required: - - identifier - - provider_resource_id - - provider_id - - type - - dataset_id - - scoring_functions - - metadata + title: UnionType Model: type: object properties: @@ -3473,11 +3826,13 @@ components: 
- type - metadata - model_type + title: Model ModelType: type: string enum: - llm - embedding + title: ModelType PaginatedRowsResult: type: object properties: @@ -3501,6 +3856,7 @@ components: required: - rows - total_count + title: PaginatedRowsResult ScoringFn: type: object properties: @@ -3538,6 +3894,7 @@ components: - type - metadata - return_type + title: ScoringFn Shield: type: object properties: @@ -3567,6 +3924,7 @@ components: - provider_resource_id - provider_id - type + title: Shield description: >- A safety shield resource that can be used to check content Span: @@ -3602,11 +3960,13 @@ components: - trace_id - name - start_time + title: Span SpanStatus: type: string enum: - ok - error + title: SpanStatus SpanWithStatus: type: object properties: @@ -3642,6 +4002,7 @@ components: - trace_id - name - start_time + title: SpanWithStatus QuerySpanTreeResponse: type: object properties: @@ -3652,6 +4013,7 @@ components: additionalProperties: false required: - data + title: QuerySpanTreeResponse Tool: type: object properties: @@ -3695,12 +4057,14 @@ components: - tool_host - description - parameters + title: Tool ToolHost: type: string enum: - distribution - client - model_context_protocol + title: ToolHost ToolGroup: type: object properties: @@ -3732,6 +4096,7 @@ components: - provider_resource_id - provider_id - type + title: ToolGroup Trace: type: object properties: @@ -3750,8 +4115,10 @@ components: - trace_id - root_span_id - start_time + title: Trace Checkpoint: description: Checkpoint created during training runs + title: Checkpoint PostTrainingJobArtifactsResponse: type: object properties: @@ -3765,14 +4132,8 @@ components: required: - job_uuid - checkpoints + title: PostTrainingJobArtifactsResponse description: Artifacts of a finetuning job. - JobStatus: - type: string - enum: - - completed - - in_progress - - failed - - scheduled PostTrainingJobStatusResponse: type: object properties: @@ -3808,6 +4169,7 @@ components: - job_uuid - status - checkpoints + title: PostTrainingJobStatusResponse description: Status of a finetuning job. 
ListPostTrainingJobsResponse: type: object @@ -3822,9 +4184,11 @@ components: additionalProperties: false required: - job_uuid + title: PostTrainingJob additionalProperties: false required: - data + title: ListPostTrainingJobsResponse VectorDB: type: object properties: @@ -3850,6 +4214,7 @@ components: - type - embedding_model - embedding_dimension + title: VectorDB HealthInfo: type: object properties: @@ -3858,6 +4223,7 @@ components: additionalProperties: false required: - status + title: HealthInfo RAGDocument: type: object properties: @@ -3888,6 +4254,7 @@ components: - document_id - content - metadata + title: RAGDocument InsertRequest: type: object properties: @@ -3904,6 +4271,7 @@ components: - documents - vector_db_id - chunk_size_in_tokens + title: InsertRequest InsertChunksRequest: type: object properties: @@ -3930,12 +4298,14 @@ components: required: - content - metadata + title: Chunk ttl_seconds: type: integer additionalProperties: false required: - vector_db_id - chunks + title: InsertChunksRequest InvokeToolRequest: type: object properties: @@ -3955,6 +4325,7 @@ components: required: - tool_name - kwargs + title: InvokeToolRequest ToolInvocationResult: type: object properties: @@ -3967,6 +4338,7 @@ components: additionalProperties: false required: - content + title: ToolInvocationResult ListDatasetsResponse: type: object properties: @@ -3977,16 +4349,7 @@ components: additionalProperties: false required: - data - ListEvalTasksResponse: - type: object - properties: - data: - type: array - items: - $ref: '#/components/schemas/EvalTask' - additionalProperties: false - required: - - data + title: ListDatasetsResponse ListModelsResponse: type: object properties: @@ -3997,6 +4360,7 @@ components: additionalProperties: false required: - data + title: ListModelsResponse ProviderInfo: type: object properties: @@ -4011,6 +4375,7 @@ components: - api - provider_id - provider_type + title: ProviderInfo ListProvidersResponse: type: object properties: @@ -4021,6 +4386,7 @@ components: additionalProperties: false required: - data + title: ListProvidersResponse RouteInfo: type: object properties: @@ -4037,6 +4403,7 @@ components: - route - method - provider_types + title: RouteInfo ListRoutesResponse: type: object properties: @@ -4047,6 +4414,7 @@ components: additionalProperties: false required: - data + title: ListRoutesResponse ListScoringFunctionsResponse: type: object properties: @@ -4057,6 +4425,7 @@ components: additionalProperties: false required: - data + title: ListScoringFunctionsResponse ListShieldsResponse: type: object properties: @@ -4067,6 +4436,7 @@ components: additionalProperties: false required: - data + title: ListShieldsResponse ListToolGroupsResponse: type: object properties: @@ -4077,6 +4447,7 @@ components: additionalProperties: false required: - data + title: ListToolGroupsResponse ListToolsResponse: type: object properties: @@ -4087,6 +4458,7 @@ components: additionalProperties: false required: - data + title: ListToolsResponse ListVectorDBsResponse: type: object properties: @@ -4097,6 +4469,7 @@ components: additionalProperties: false required: - data + title: ListVectorDBsResponse Event: oneOf: - $ref: '#/components/schemas/UnstructuredLogEvent' @@ -4117,6 +4490,7 @@ components: - warn - error - critical + title: LogSeverity SpanEndPayload: type: object properties: @@ -4130,6 +4504,7 @@ components: required: - type - status + title: SpanEndPayload SpanStartPayload: type: object properties: @@ -4145,6 +4520,7 @@ components: required: - type - name + title: 
SpanStartPayload StructuredLogEvent: type: object properties: @@ -4177,6 +4553,7 @@ components: - timestamp - type - payload + title: StructuredLogEvent StructuredLogPayload: oneOf: - $ref: '#/components/schemas/SpanStartPayload' @@ -4221,6 +4598,7 @@ components: - type - message - severity + title: UnstructuredLogEvent LogEventRequest: type: object properties: @@ -4232,6 +4610,7 @@ components: required: - event - ttl_seconds + title: LogEventRequest DPOAlignmentConfig: type: object properties: @@ -4249,6 +4628,7 @@ components: - reward_clip - epsilon - gamma + title: DPOAlignmentConfig DataConfig: type: object properties: @@ -4274,11 +4654,13 @@ components: - batch_size - shuffle - data_format + title: DataConfig DatasetFormat: type: string enum: - instruct - dialog + title: DatasetFormat EfficiencyConfig: type: object properties: @@ -4295,6 +4677,7 @@ components: type: boolean default: false additionalProperties: false + title: EfficiencyConfig OptimizerConfig: type: object properties: @@ -4312,12 +4695,14 @@ components: - lr - weight_decay - num_warmup_steps + title: OptimizerConfig OptimizerType: type: string enum: - adam - adamw - sgd + title: OptimizerType TrainingConfig: type: object properties: @@ -4346,6 +4731,7 @@ components: - max_validation_steps - data_config - optimizer_config + title: TrainingConfig PreferenceOptimizeRequest: type: object properties: @@ -4385,6 +4771,7 @@ components: - training_config - hyperparam_search_config - logger_config + title: PreferenceOptimizeRequest PostTrainingJob: type: object properties: @@ -4393,6 +4780,7 @@ components: additionalProperties: false required: - job_uuid + title: PostTrainingJob DefaultRAGQueryGeneratorConfig: type: object properties: @@ -4407,6 +4795,7 @@ components: required: - type - separator + title: DefaultRAGQueryGeneratorConfig LLMRAGQueryGeneratorConfig: type: object properties: @@ -4423,6 +4812,7 @@ components: - type - model - template + title: LLMRAGQueryGeneratorConfig RAGQueryConfig: type: object properties: @@ -4439,6 +4829,7 @@ components: - query_generator_config - max_tokens_in_context - max_chunks + title: RAGQueryConfig RAGQueryGeneratorConfig: oneOf: - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig' @@ -4463,12 +4854,14 @@ components: required: - content - vector_db_ids + title: QueryRequest RAGQueryResult: type: object properties: content: $ref: '#/components/schemas/InterleavedContent' additionalProperties: false + title: RAGQueryResult QueryChunksRequest: type: object properties: @@ -4490,6 +4883,7 @@ components: required: - vector_db_id - query + title: QueryChunksRequest QueryChunksResponse: type: object properties: @@ -4514,6 +4908,7 @@ components: required: - content - metadata + title: Chunk scores: type: array items: @@ -4522,6 +4917,7 @@ components: required: - chunks - scores + title: QueryChunksResponse QueryCondition: type: object properties: @@ -4542,6 +4938,7 @@ components: - key - op - value + title: QueryCondition QueryConditionOp: type: string enum: @@ -4549,6 +4946,7 @@ components: - ne - gt - lt + title: QueryConditionOp QuerySpansResponse: type: object properties: @@ -4559,6 +4957,7 @@ components: additionalProperties: false required: - data + title: QuerySpansResponse QueryTracesResponse: type: object properties: @@ -4569,6 +4968,38 @@ components: additionalProperties: false required: - data + title: QueryTracesResponse + RegisterBenchmarkRequest: + type: object + properties: + benchmark_id: + type: string + dataset_id: + type: string + scoring_functions: + type: array + 
items: + type: string + provider_benchmark_id: + type: string + provider_id: + type: string + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - benchmark_id + - dataset_id + - scoring_functions + title: RegisterBenchmarkRequest RegisterDatasetRequest: type: object properties: @@ -4599,36 +5030,7 @@ components: - dataset_id - dataset_schema - url - RegisterEvalTaskRequest: - type: object - properties: - eval_task_id: - type: string - dataset_id: - type: string - scoring_functions: - type: array - items: - type: string - provider_eval_task_id: - type: string - provider_id: - type: string - metadata: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - additionalProperties: false - required: - - eval_task_id - - dataset_id - - scoring_functions + title: RegisterDatasetRequest RegisterModelRequest: type: object properties: @@ -4653,6 +5055,7 @@ components: additionalProperties: false required: - model_id + title: RegisterModelRequest RegisterScoringFunctionRequest: type: object properties: @@ -4673,6 +5076,7 @@ components: - scoring_fn_id - description - return_type + title: RegisterScoringFunctionRequest RegisterShieldRequest: type: object properties: @@ -4695,6 +5099,7 @@ components: additionalProperties: false required: - shield_id + title: RegisterShieldRequest RegisterToolGroupRequest: type: object properties: @@ -4718,6 +5123,7 @@ components: required: - toolgroup_id - provider_id + title: RegisterToolGroupRequest RegisterVectorDbRequest: type: object properties: @@ -4735,22 +5141,16 @@ components: required: - vector_db_id - embedding_model + title: RegisterVectorDbRequest RunEvalRequest: type: object properties: task_config: - $ref: '#/components/schemas/EvalTaskConfig' + $ref: '#/components/schemas/BenchmarkConfig' additionalProperties: false required: - task_config - Job: - type: object - properties: - job_id: - type: string - additionalProperties: false - required: - - job_id + title: RunEvalRequest RunShieldRequest: type: object properties: @@ -4775,12 +5175,14 @@ components: - shield_id - messages - params + title: RunShieldRequest RunShieldResponse: type: object properties: violation: $ref: '#/components/schemas/SafetyViolation' additionalProperties: false + title: RunShieldResponse SaveSpansToDatasetRequest: type: object properties: @@ -4801,6 +5203,7 @@ components: - attribute_filters - attributes_to_save - dataset_id + title: SaveSpansToDatasetRequest ScoreRequest: type: object properties: @@ -4826,6 +5229,7 @@ components: required: - input_rows - scoring_functions + title: ScoreRequest ScoreResponse: type: object properties: @@ -4836,6 +5240,7 @@ components: additionalProperties: false required: - results + title: ScoreResponse ScoreBatchRequest: type: object properties: @@ -4854,6 +5259,7 @@ components: - dataset_id - scoring_functions - save_results_dataset + title: ScoreBatchRequest ScoreBatchResponse: type: object properties: @@ -4866,6 +5272,7 @@ components: additionalProperties: false required: - results + title: ScoreBatchResponse AlgorithmConfig: oneOf: - $ref: '#/components/schemas/LoraFinetuningConfig' @@ -4908,6 +5315,7 @@ components: - apply_lora_to_output - rank - alpha + title: LoraFinetuningConfig QATFinetuningConfig: type: object properties: @@ -4924,6 +5332,7 @@ components: - type - quantizer_name - group_size 
+ title: QATFinetuningConfig SupervisedFineTuneRequest: type: object properties: @@ -4964,6 +5373,7 @@ components: - hyperparam_search_config - logger_config - model + title: SupervisedFineTuneRequest SyntheticDataGenerateRequest: type: object properties: @@ -4980,6 +5390,7 @@ components: - top_p - top_k_top_p - sigmoid + title: FilteringFunction description: The type of filtering function. model: type: string @@ -4987,6 +5398,7 @@ components: required: - dialogs - filtering_function + title: SyntheticDataGenerateRequest SyntheticDataGenerationResponse: type: object properties: @@ -5015,6 +5427,7 @@ components: additionalProperties: false required: - synthetic_data + title: SyntheticDataGenerationResponse description: >- Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold. @@ -5026,6 +5439,7 @@ components: additionalProperties: false required: - version + title: VersionInfo responses: {} security: - Default: [] @@ -5049,10 +5463,10 @@ tags: x-displayName: >- Agents API for creating and interacting with agentic systems. - name: BatchInference (Coming Soon) + - name: Benchmarks - name: DatasetIO - name: Datasets - name: Eval - - name: EvalTasks - name: Inference description: >- This API provides the raw interface to the underlying models. Two kinds of models @@ -5083,10 +5497,10 @@ x-tagGroups: tags: - Agents - BatchInference (Coming Soon) + - Benchmarks - DatasetIO - Datasets - Eval - - EvalTasks - Inference - Inspect - Models diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index abe537c8e..51ae945f4 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -324,7 +324,7 @@ "- vector_io\n", "container_image: null\n", "datasets: []\n", - "eval_tasks: []\n", + "benchmarks: []\n", "image_name: together\n", "metadata_store:\n", " db_path: /Users/ashwin/.llama/distributions/together/registry.db\n", @@ -508,7 +508,7 @@ "- vector_io\n", "container_image: null\n", "datasets: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", - "eval_tasks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", + "benchmarks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", "image_name: together\n", "metadata_store:\n", " db_path: \u001b[35m/Users/ashwin/.llama/distributions/together/\u001b[0m\u001b[95mregistry.db\u001b[0m\n", @@ -3419,22 +3419,22 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "865fc5a8", - "metadata": {}, - "outputs": [], - "source": [ - "!pip install llama-stack-client==0.1.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "44e05e16", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " % Total % Received % Xferd Average Speed Time Time Time Current\n", + " Dload Upload Total Spent Left Speed\n", + "100 275k 100 275k 0 0 780k 0 --:--:-- --:--:-- --:--:-- 780k\n" + ] + } + ], "source": [ - "!wget https://raw.githubusercontent.com/meta-llama/llama-models/refs/heads/main/Llama_Repo.jpeg" + "!curl -O https://raw.githubusercontent.com/meta-llama/llama-models/refs/heads/main/Llama_Repo.jpeg" ] }, { @@ -3444,6 +3444,7 @@ "metadata": {}, "outputs": [], "source": [ + "# NBVAL_SKIP\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "\n", @@ -3580,6 +3581,7 @@ " model=LLAMA32_11B_INSTRUCT,\n", " instructions=\"You are a helpful assistant\",\n", " enable_session_persistence=False,\n", + " toolgroups=[],\n", " )\n", "\n", " agent = Agent(client, agent_config)\n", @@ -3630,7 +3632,7 @@ "provenance": 
[] }, "kernelspec": { - "display_name": "toolchain", + "display_name": "master", "language": "python", "name": "python3" }, @@ -3644,7 +3646,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.16" }, "widgets": { "application/vnd.jupyter.widget-state+json": { diff --git a/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb b/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb index 84da25246..8eecf84ab 100644 --- a/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb +++ b/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb @@ -370,7 +370,7 @@ "- tool_runtime\n", "datasets: []\n", "container_image: null\n", - "eval_tasks: []\n", + "benchmarks: []\n", "image_name: together\n", "memory_banks: []\n", "metadata_store:\n", @@ -551,7 +551,7 @@ "- tool_runtime\n", "datasets: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", "container_image: null\n", - "eval_tasks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", + "benchmarks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", "image_name: together\n", "memory_banks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", "metadata_store:\n", diff --git a/docs/openapi_generator/README.md b/docs/openapi_generator/README.md index 9d407905d..e98cfaf1b 100644 --- a/docs/openapi_generator/README.md +++ b/docs/openapi_generator/README.md @@ -1,4 +1,4 @@ -The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/[]/api/endpoints.py` using the `generate.py` utility. +The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/distribution/server/endpoints.py` using the `generate.py` utility. Please install the following packages before running the script: @@ -6,4 +6,4 @@ Please install the following packages before running the script: pip install python-openapi json-strong-typing fire PyYAML llama-models ``` -Then simply run `sh run_openapi_generator.sh ` +Then simply run `sh run_openapi_generator.sh` diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py index 48109e5d8..dcbee7d2f 100644 --- a/docs/openapi_generator/generate.py +++ b/docs/openapi_generator/generate.py @@ -16,18 +16,6 @@ from pathlib import Path import fire import ruamel.yaml as yaml -from llama_models import schema_utils - -# We do some monkey-patching to ensure our definitions only use the minimal -# (json_schema_type, webmethod) definitions from the llama_models package. For -# generation though, we need the full definitions and implementations from the -# (json-strong-typing) package. 
- -from .strong_typing.schema import json_schema_type, register_schema - -schema_utils.json_schema_type = json_schema_type -schema_utils.register_schema = register_schema - from llama_stack.apis.version import LLAMA_STACK_API_VERSION # noqa: E402 from llama_stack.distribution.stack import LlamaStack # noqa: E402 diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py index a0385cae0..60cd7a242 100644 --- a/docs/openapi_generator/pyopenapi/generator.py +++ b/docs/openapi_generator/pyopenapi/generator.py @@ -10,9 +10,9 @@ import typing from dataclasses import make_dataclass from typing import Any, Dict, Set, Union -from ..strong_typing.core import JsonType -from ..strong_typing.docstring import Docstring, parse_type -from ..strong_typing.inspection import ( +from llama_stack.strong_typing.core import JsonType +from llama_stack.strong_typing.docstring import Docstring, parse_type +from llama_stack.strong_typing.inspection import ( is_generic_list, is_type_optional, is_type_union, @@ -20,15 +20,15 @@ from ..strong_typing.inspection import ( unwrap_optional_type, unwrap_union_types, ) -from ..strong_typing.name import python_type_to_name -from ..strong_typing.schema import ( +from llama_stack.strong_typing.name import python_type_to_name +from llama_stack.strong_typing.schema import ( get_schema_identifier, JsonSchemaGenerator, register_schema, Schema, SchemaOptions, ) -from ..strong_typing.serialization import json_dump_string, object_to_json +from llama_stack.strong_typing.serialization import json_dump_string, object_to_json from .operations import ( EndpointOperation, @@ -647,6 +647,7 @@ class Generator: description = "\n".join( filter(None, [doc_string.short_description, doc_string.long_description]) ) + return Operation( tags=[op.defining_class.__name__], summary=None, @@ -656,6 +657,7 @@ class Generator: requestBody=requestBody, responses=responses, callbacks=callbacks, + deprecated=True if "DEPRECATED" in op.func_name else None, security=[] if op.public else None, ) diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py index bf4d35c87..88a403182 100644 --- a/docs/openapi_generator/pyopenapi/operations.py +++ b/docs/openapi_generator/pyopenapi/operations.py @@ -15,7 +15,7 @@ from llama_stack.apis.version import LLAMA_STACK_API_VERSION from termcolor import colored -from ..strong_typing.inspection import get_signature +from llama_stack.strong_typing.inspection import get_signature def split_prefix( diff --git a/docs/openapi_generator/pyopenapi/specification.py b/docs/openapi_generator/pyopenapi/specification.py index 4b54295c5..9e5363b4a 100644 --- a/docs/openapi_generator/pyopenapi/specification.py +++ b/docs/openapi_generator/pyopenapi/specification.py @@ -9,7 +9,7 @@ import enum from dataclasses import dataclass from typing import Any, ClassVar, Dict, List, Optional, Union -from ..strong_typing.schema import JsonType, Schema, StrictJsonType +from llama_stack.strong_typing.schema import JsonType, Schema, StrictJsonType URL = str @@ -117,6 +117,7 @@ class Operation: requestBody: Optional[RequestBody] = None callbacks: Optional[Dict[str, "Callback"]] = None security: Optional[List["SecurityRequirement"]] = None + deprecated: Optional[bool] = None @dataclass diff --git a/docs/openapi_generator/pyopenapi/utility.py b/docs/openapi_generator/pyopenapi/utility.py index 54f10d473..f134aab4b 100644 --- a/docs/openapi_generator/pyopenapi/utility.py +++ 
b/docs/openapi_generator/pyopenapi/utility.py @@ -9,7 +9,7 @@ import typing from pathlib import Path from typing import TextIO -from ..strong_typing.schema import object_to_json, StrictJsonType +from llama_stack.strong_typing.schema import object_to_json, StrictJsonType from .generator import Generator from .options import Options diff --git a/docs/source/building_applications/evals.md b/docs/source/building_applications/evals.md index c4cb476e4..f28e0d5fd 100644 --- a/docs/source/building_applications/evals.md +++ b/docs/source/building_applications/evals.md @@ -41,14 +41,14 @@ system_message = { "content": SYSTEM_PROMPT_TEMPLATE, } -client.eval_tasks.register( - eval_task_id="meta-reference::mmmu", +client.benchmarks.register( + benchmark_id="meta-reference::mmmu", dataset_id=f"mmmu-{subset}-{split}", scoring_functions=["basic::regex_parser_multiple_choice_answer"], ) response = client.eval.evaluate_rows( - task_id="meta-reference::mmmu", + benchmark_id="meta-reference::mmmu", input_rows=eval_rows, scoring_functions=["basic::regex_parser_multiple_choice_answer"], task_config={ @@ -99,14 +99,14 @@ eval_rows = client.datasetio.get_rows_paginated( ``` ```python -client.eval_tasks.register( - eval_task_id="meta-reference::simpleqa", +client.benchmarks.register( + benchmark_id="meta-reference::simpleqa", dataset_id=simpleqa_dataset_id, scoring_functions=["llm-as-judge::405b-simpleqa"], ) response = client.eval.evaluate_rows( - task_id="meta-reference::simpleqa", + benchmark_id="meta-reference::simpleqa", input_rows=eval_rows.rows, scoring_functions=["llm-as-judge::405b-simpleqa"], task_config={ @@ -156,7 +156,7 @@ agent_config = { } response = client.eval.evaluate_rows( - task_id="meta-reference::simpleqa", + benchmark_id="meta-reference::simpleqa", input_rows=eval_rows.rows, scoring_functions=["llm-as-judge::405b-simpleqa"], task_config={ diff --git a/docs/source/building_applications/evaluation.md b/docs/source/building_applications/evaluation.md index 91e5c552b..ad220f751 100644 --- a/docs/source/building_applications/evaluation.md +++ b/docs/source/building_applications/evaluation.md @@ -10,15 +10,15 @@ Here's how to set up basic evaluation: ```python # Create an evaluation task -response = client.eval_tasks.register( - eval_task_id="my_eval", +response = client.benchmarks.register( + benchmark_id="my_eval", dataset_id="my_dataset", scoring_functions=["accuracy", "relevance"], ) # Run evaluation job = client.eval.run_eval( - task_id="my_eval", + benchmark_id="my_eval", task_config={ "type": "app", "eval_candidate": {"type": "agent", "config": agent_config}, @@ -26,5 +26,5 @@ job = client.eval.run_eval( ) # Get results -result = client.eval.job_result(task_id="my_eval", job_id=job.job_id) +result = client.eval.job_result(benchmark_id="my_eval", job_id=job.job_id) ``` diff --git a/docs/source/concepts/evaluation_concepts.md b/docs/source/concepts/evaluation_concepts.md index 399d99d92..3ca4b0ac8 100644 --- a/docs/source/concepts/evaluation_concepts.md +++ b/docs/source/concepts/evaluation_concepts.md @@ -5,7 +5,7 @@ The Llama Stack Evaluation flow allows you to run evaluations on your GenAI appl We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications. - `/datasetio` + `/datasets` API - `/scoring` + `/scoring_functions` API -- `/eval` + `/eval_tasks` API +- `/eval` + `/benchmarks` API This guide goes over the sets of APIs and developer experience flow of using Llama Stack to run evaluations for different use cases. 
Checkout our Colab notebook on working examples with evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing). @@ -21,7 +21,7 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo - **Scoring**: evaluate outputs of the system. - Associated with `ScoringFunction` resource. We provide a suite of out-of-the box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics. - **Eval**: generate outputs (via Inference or Agents) and perform scoring. - - Associated with `EvalTask` resource. + - Associated with `Benchmark` resource. Use the following decision tree to decide how to use LlamaStack Evaluation flow. diff --git a/docs/source/concepts/index.md b/docs/source/concepts/index.md index 1437ec623..403e47c48 100644 --- a/docs/source/concepts/index.md +++ b/docs/source/concepts/index.md @@ -42,7 +42,7 @@ Some of these APIs are associated with a set of **Resources**. Here is the mappi - **Tool Runtime** is associated with `ToolGroup` resources. - **DatasetIO** is associated with `Dataset` resources. - **Scoring** is associated with `ScoringFunction` resources. -- **Eval** is associated with `Model` and `EvalTask` resources. +- **Eval** is associated with `Model` and `Benchmark` resources. Furthermore, we allow these resources to be **federated** across multiple providers. For example, you may have some Llama models served by Fireworks while others are served by AWS Bedrock. Regardless, they will all work seamlessly with the same uniform Inference API provided by Llama Stack. diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index 90239cb4e..9cb1a402f 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -23,7 +23,8 @@ The main points to consider are: ``` llama stack build -h -usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--list-templates | --no-list-templates] [--image-type {conda,container,venv}] [--image-name IMAGE_NAME] +usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--list-templates] + [--image-type {conda,container,venv}] [--image-name IMAGE_NAME] [--print-deps-only] Build a Llama stack container @@ -32,14 +33,14 @@ options: --config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively --template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates - --list-templates, --no-list-templates - Show the available templates for building a Llama Stack distribution (default: False) + --list-templates Show the available templates for building a Llama Stack distribution --image-type {conda,container,venv} Image Type to use for the build. This can be either conda or container or venv. If not specified, will use the image type from the template config. --image-name IMAGE_NAME [for image-type=conda] Name of the conda environment to use for the build. If not specified, currently active Conda environment will be used. If no Conda environment is active, you must specify a name. 
+ --print-deps-only Print the dependencies for the stack only, without building the stack ``` After this step is complete, a file named `-build.yaml` and template file `-run.yaml` will be generated and saved at the output file path specified at the end of the command. diff --git a/docs/source/index.md b/docs/source/index.md index 2834f5641..cb2355bfd 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -2,7 +2,7 @@ ```{admonition} News :class: tip -Llama Stack 0.1.2 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.2) for more details. +Llama Stack 0.1.3 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.3) for more details. ``` # Llama Stack diff --git a/docs/source/playground/index.md b/docs/source/playground/index.md index d74bf1a03..9691609ab 100644 --- a/docs/source/playground/index.md +++ b/docs/source/playground/index.md @@ -64,7 +64,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie ``` ```bash - $ llama-stack-client eval_tasks register \ + $ llama-stack-client benchmarks register \ --eval-task-id meta-reference-mmlu \ --provider-id meta-reference \ --dataset-id mmlu \ @@ -86,7 +86,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie - Under the hood, it uses Llama Stack's `/providers` API to get information about the providers. - **API Resources**: Inspect Llama Stack API resources - - This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `eval_tasks`, `shields`). + - This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `benchmarks`, `shields`). - Under the hood, it uses Llama Stack's `//list` API to get information about each resources. - Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources. diff --git a/docs/source/references/evals_reference/index.md b/docs/source/references/evals_reference/index.md index 86f66208a..71dbb47e5 100644 --- a/docs/source/references/evals_reference/index.md +++ b/docs/source/references/evals_reference/index.md @@ -5,7 +5,7 @@ The Llama Stack Evaluation flow allows you to run evaluations on your GenAI appl We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications. - `/datasetio` + `/datasets` API - `/scoring` + `/scoring_functions` API -- `/eval` + `/eval_tasks` API +- `/eval` + `/benchmarks` API This guide goes over the sets of APIs and developer experience flow of using Llama Stack to run evaluations for different use cases. Checkout our Colab notebook on working examples with evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing). @@ -21,7 +21,7 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo - **Scoring**: evaluate outputs of the system. - Associated with `ScoringFunction` resource. We provide a suite of out-of-the box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics. - **Eval**: generate outputs (via Inference or Agents) and perform scoring. - - Associated with `EvalTask` resource. + - Associated with `Benchmark` resource. Use the following decision tree to decide how to use LlamaStack Evaluation flow. 
@@ -77,14 +77,14 @@ system_message = { "content": SYSTEM_PROMPT_TEMPLATE, } -client.eval_tasks.register( - eval_task_id="meta-reference::mmmu", +client.benchmarks.register( + benchmark_id="meta-reference::mmmu", dataset_id=f"mmmu-{subset}-{split}", scoring_functions=["basic::regex_parser_multiple_choice_answer"], ) response = client.eval.evaluate_rows( - task_id="meta-reference::mmmu", + benchmark_id="meta-reference::mmmu", input_rows=eval_rows, scoring_functions=["basic::regex_parser_multiple_choice_answer"], task_config={ @@ -135,14 +135,14 @@ eval_rows = client.datasetio.get_rows_paginated( ``` ```python -client.eval_tasks.register( - eval_task_id="meta-reference::simpleqa", +client.benchmarks.register( + benchmark_id="meta-reference::simpleqa", dataset_id=simpleqa_dataset_id, scoring_functions=["llm-as-judge::405b-simpleqa"], ) response = client.eval.evaluate_rows( - task_id="meta-reference::simpleqa", + benchmark_id="meta-reference::simpleqa", input_rows=eval_rows.rows, scoring_functions=["llm-as-judge::405b-simpleqa"], task_config={ @@ -192,7 +192,7 @@ agent_config = { } response = client.eval.evaluate_rows( - task_id="meta-reference::simpleqa", + benchmark_id="meta-reference::simpleqa", input_rows=eval_rows.rows, scoring_functions=["llm-as-judge::405b-simpleqa"], task_config={ @@ -281,7 +281,7 @@ The following examples give the quick steps to start running evaluations using t #### Benchmark Evaluation CLI Usage: There are 2 inputs necessary for running a benchmark eval -- `eval-task-id`: the identifier associated with the eval task. Each `EvalTask` is parametrized by +- `eval-task-id`: the identifier associated with the eval task. Each `Benchmark` is parametrized by - `dataset_id`: the identifier associated with the dataset. - `List[scoring_function_id]`: list of scoring function identifiers. - `eval-task-config`: specifies the configuration of the model / agent to evaluate on. @@ -289,7 +289,7 @@ Usage: There are 2 inputs necessary for running a benchmark eval ``` llama-stack-client eval run_benchmark \ ---eval-task-config ~/eval_task_config.json \ +--eval-task-config ~/benchmark_config.json \ --visualize ``` @@ -309,15 +309,15 @@ llama-stack-client eval run_scoring ... --dataset-id --scoring-functions [ ...] [--provider-id ] [--provider-eval-task-id ] [--metadata ] +$ llama-stack-client benchmarks register --eval-task-id --dataset-id --scoring-functions [ ...] [--provider-id ] [--provider-eval-task-id ] [--metadata ] ``` Options: @@ -191,7 +191,7 @@ Options: - `--num-examples`: Optional. Number of examples to evaluate (useful for debugging) - `--visualize`: Optional flag. 
If set, visualizes evaluation results after completion -Example eval_task_config.json: +Example benchmark_config.json: ```json { "type": "benchmark", diff --git a/docs/source/references/python_sdk_reference/index.md b/docs/source/references/python_sdk_reference/index.md index 8a06e2244..9d1130422 100644 --- a/docs/source/references/python_sdk_reference/index.md +++ b/docs/source/references/python_sdk_reference/index.md @@ -181,8 +181,8 @@ from llama_stack_client.types import EvaluateResponse, Job Methods: -- client.eval.evaluate_rows(task_id, \*\*params) -> EvaluateResponse -- client.eval.run_eval(task_id, \*\*params) -> Job +- client.eval.evaluate_rows(benchmark_id, \*\*params) -> EvaluateResponse +- client.eval.run_eval(benchmark_id, \*\*params) -> Job ### Jobs @@ -194,9 +194,9 @@ from llama_stack_client.types.eval import JobStatusResponse Methods: -- client.eval.jobs.retrieve(job_id, \*, task_id) -> EvaluateResponse -- client.eval.jobs.cancel(job_id, \*, task_id) -> None -- client.eval.jobs.status(job_id, \*, task_id) -> Optional[JobStatusResponse] +- client.eval.jobs.retrieve(job_id, \*, benchmark_id) -> EvaluateResponse +- client.eval.jobs.cancel(job_id, \*, benchmark_id) -> None +- client.eval.jobs.status(job_id, \*, benchmark_id) -> Optional[JobStatusResponse] ## Inspect @@ -443,20 +443,20 @@ Methods: - client.scoring_functions.list() -> ScoringFunctionListResponse - client.scoring_functions.register(\*\*params) -> None -## EvalTasks +## Benchmarks Types: ```python from llama_stack_client.types import ( - EvalTask, - ListEvalTasksResponse, - EvalTaskListResponse, + Benchmark, + ListBenchmarksResponse, + BenchmarkListResponse, ) ``` Methods: -- client.eval_tasks.retrieve(eval_task_id) -> Optional[EvalTask] -- client.eval_tasks.list() -> EvalTaskListResponse -- client.eval_tasks.register(\*\*params) -> None +- client.benchmarks.retrieve(benchmark_id) -> Optional[Benchmark] +- client.benchmarks.list() -> BenchmarkListResponse +- client.benchmarks.register(\*\*params) -> None diff --git a/docs/zero_to_hero_guide/Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb b/docs/zero_to_hero_guide/Tool_Calling101_Using_Together_Llama_Stack_Server.ipynb similarity index 100% rename from docs/zero_to_hero_guide/Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb rename to docs/zero_to_hero_guide/Tool_Calling101_Using_Together_Llama_Stack_Server.ipynb diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 106d34584..367648ded 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -19,7 +19,6 @@ from typing import ( runtime_checkable, ) -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent @@ -38,6 +37,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.tools import ToolDef from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod class Attachment(BaseModel): @@ -179,7 +179,7 @@ class AgentConfigCommon(BaseModel): class AgentConfig(AgentConfigCommon): model: str instructions: str - enable_session_persistence: bool + enable_session_persistence: Optional[bool] = False response_format: Optional[ResponseFormat] = None diff --git a/llama_stack/apis/agents/event_logger.py 
b/llama_stack/apis/agents/event_logger.py deleted file mode 100644 index 835ce4cee..000000000 --- a/llama_stack/apis/agents/event_logger.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Optional - -from llama_models.llama3.api.datatypes import ToolPromptFormat -from llama_models.llama3.api.tool_utils import ToolUtils -from termcolor import cprint - -from llama_stack.apis.agents import AgentTurnResponseEventType, StepType -from llama_stack.apis.common.content_types import ToolCallParseStatus -from llama_stack.apis.inference import ToolResponseMessage -from llama_stack.providers.utils.inference.prompt_adapter import ( - interleaved_content_as_str, -) - - -class LogEvent: - def __init__( - self, - role: Optional[str] = None, - content: str = "", - end: str = "\n", - color="white", - ): - self.role = role - self.content = content - self.color = color - self.end = "\n" if end is None else end - - def __str__(self): - if self.role is not None: - return f"{self.role}> {self.content}" - else: - return f"{self.content}" - - def print(self, flush=True): - cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush) - - -EventType = AgentTurnResponseEventType - - -class EventLogger: - async def log( - self, - event_generator, - stream=True, - tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, - ): - previous_event_type = None - previous_step_type = None - - async for chunk in event_generator: - if not hasattr(chunk, "event"): - # Need to check for custom tool first - # since it does not produce event but instead - # a Message - if isinstance(chunk, ToolResponseMessage): - yield ( - chunk, - LogEvent(role="CustomTool", content=chunk.content, color="grey"), - ) - continue - - event = chunk.event - event_type = event.payload.event_type - if event_type in { - EventType.turn_start.value, - EventType.turn_complete.value, - }: - # Currently not logging any turn realted info - yield event, None - continue - - step_type = event.payload.step_type - # handle safety - if step_type == StepType.shield_call and event_type == EventType.step_complete.value: - violation = event.payload.step_details.violation - if not violation: - yield ( - event, - LogEvent(role=step_type, content="No Violation", color="magenta"), - ) - else: - yield ( - event, - LogEvent( - role=step_type, - content=f"{violation.metadata} {violation.user_message}", - color="red", - ), - ) - - # handle inference - if step_type == StepType.inference: - if stream: - if event_type == EventType.step_start.value: - # TODO: Currently this event is never received - yield ( - event, - LogEvent(role=step_type, content="", end="", color="yellow"), - ) - elif event_type == EventType.step_progress.value: - # HACK: if previous was not step/event was not inference's step_progress - # this is the first time we are getting model inference response - # aka equivalent to step_start for inference. Hence, - # start with "Model>". 
- if ( - previous_event_type != EventType.step_progress.value - and previous_step_type != StepType.inference - ): - yield ( - event, - LogEvent(role=step_type, content="", end="", color="yellow"), - ) - - delta = event.payload.delta - if delta.type == "tool_call": - if delta.parse_status == ToolCallParseStatus.succeeded: - yield ( - event, - LogEvent( - role=None, - content=delta.tool_call, - end="", - color="cyan", - ), - ) - else: - yield ( - event, - LogEvent( - role=None, - content=delta.text, - end="", - color="yellow", - ), - ) - else: - # step_complete - yield event, LogEvent(role=None, content="") - - else: - # Not streaming - if event_type == EventType.step_complete.value: - response = event.payload.step_details.model_response - if response.tool_calls: - content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format) - else: - content = response.content - yield ( - event, - LogEvent( - role=step_type, - content=content, - color="yellow", - ), - ) - - # handle tool_execution - if ( - step_type == StepType.tool_execution - and - # Only print tool calls and responses at the step_complete event - event_type == EventType.step_complete.value - ): - details = event.payload.step_details - for t in details.tool_calls: - yield ( - event, - LogEvent( - role=step_type, - content=f"Tool:{t.tool_name} Args:{t.arguments}", - color="green", - ), - ) - for r in details.tool_responses: - yield ( - event, - LogEvent( - role=step_type, - content=f"Tool:{r.tool_name} Response:{r.content}", - color="green", - ), - ) - - if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value: - details = event.payload.step_details - inserted_context = interleaved_content_as_str(details.inserted_context) - content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}" - - yield ( - event, - LogEvent( - role=step_type, - content=content, - color="cyan", - ), - ) - - previous_event_type = event_type - previous_step_type = step_type diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py index 413c81c5a..0fa5c78ce 100644 --- a/llama_stack/apis/batch_inference/batch_inference.py +++ b/llama_stack/apis/batch_inference/batch_inference.py @@ -6,7 +6,6 @@ from typing import List, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.inference import ( @@ -21,6 +20,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type diff --git a/llama_stack/apis/eval_tasks/__init__.py b/llama_stack/apis/benchmarks/__init__.py similarity index 81% rename from llama_stack/apis/eval_tasks/__init__.py rename to llama_stack/apis/benchmarks/__init__.py index 7ca216706..f8f564957 100644 --- a/llama_stack/apis/eval_tasks/__init__.py +++ b/llama_stack/apis/benchmarks/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .eval_tasks import * # noqa: F401 F403 +from .benchmarks import * # noqa: F401 F403 diff --git a/llama_stack/apis/benchmarks/benchmarks.py b/llama_stack/apis/benchmarks/benchmarks.py new file mode 100644 index 000000000..91b1ca927 --- /dev/null +++ b/llama_stack/apis/benchmarks/benchmarks.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable + +from pydantic import BaseModel, Field + +from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.schema_utils import json_schema_type, webmethod + + +class CommonBenchmarkFields(BaseModel): + dataset_id: str + scoring_functions: List[str] + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Metadata for this evaluation task", + ) + + +@json_schema_type +class Benchmark(CommonBenchmarkFields, Resource): + type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value + + @property + def benchmark_id(self) -> str: + return self.identifier + + @property + def provider_benchmark_id(self) -> str: + return self.provider_resource_id + + +class BenchmarkInput(CommonBenchmarkFields, BaseModel): + benchmark_id: str + provider_id: Optional[str] = None + provider_benchmark_id: Optional[str] = None + + +class ListBenchmarksResponse(BaseModel): + data: List[Benchmark] + + +@runtime_checkable +class Benchmarks(Protocol): + @webmethod(route="/eval/benchmarks", method="GET") + async def list_benchmarks(self) -> ListBenchmarksResponse: ... + + @webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET") + async def get_benchmark( + self, + benchmark_id: str, + ) -> Optional[Benchmark]: ... + + @webmethod(route="/eval/benchmarks", method="POST") + async def register_benchmark( + self, + benchmark_id: str, + dataset_id: str, + scoring_functions: List[str], + provider_benchmark_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: ... + + @webmethod(route="/eval-tasks", method="GET") + async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse: ... + + @webmethod(route="/eval-tasks/{eval_task_id}", method="GET") + async def DEPRECATED_get_eval_task( + self, + eval_task_id: str, + ) -> Optional[Benchmark]: ... + + @webmethod(route="/eval-tasks", method="POST") + async def DEPRECATED_register_eval_task( + self, + eval_task_id: str, + dataset_id: str, + scoring_functions: List[str], + provider_benchmark_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: ... 
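The new `Benchmarks` protocol above exposes `/eval/benchmarks` register/list/get routes (with the old `/eval-tasks` routes kept only as deprecated aliases). A minimal client-side sketch, assuming a running Llama Stack server and an already-registered dataset and scoring function; the base URL and the specific identifiers are illustrative, not part of this patch:

```python
# Minimal, illustrative sketch (not part of this patch): exercising the new
# Benchmarks routes via llama-stack-client. The base URL, benchmark id,
# dataset id, and scoring function are placeholders and assume those
# resources have already been registered with the server.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")  # assumed server address

# POST /eval/benchmarks -- register a benchmark against an existing dataset
client.benchmarks.register(
    benchmark_id="meta-reference::mmlu",
    dataset_id="mmlu",
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)

# GET /eval/benchmarks/{benchmark_id} -- fetch the registered benchmark back
benchmark = client.benchmarks.retrieve("meta-reference::mmlu")
if benchmark is not None:
    print(benchmark.dataset_id, benchmark.scoring_functions)
```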
diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index e648f9a19..0d0afa894 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -7,10 +7,11 @@ from enum import Enum from typing import Annotated, List, Literal, Optional, Union -from llama_models.llama3.api.datatypes import ToolCall -from llama_models.schema_utils import json_schema_type, register_schema from pydantic import BaseModel, Field, model_validator +from llama_stack.models.llama.datatypes import ToolCall +from llama_stack.schema_utils import json_schema_type, register_schema + @json_schema_type class URL(BaseModel): diff --git a/llama_stack/apis/common/deployment_types.py b/llama_stack/apis/common/deployment_types.py index 16a5c8ad6..83eea28a2 100644 --- a/llama_stack/apis/common/deployment_types.py +++ b/llama_stack/apis/common/deployment_types.py @@ -7,10 +7,10 @@ from enum import Enum from typing import Any, Dict, Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel from llama_stack.apis.common.content_types import URL +from llama_stack.schema_utils import json_schema_type @json_schema_type diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py index c945bd8ff..bc070017b 100644 --- a/llama_stack/apis/common/job_types.py +++ b/llama_stack/apis/common/job_types.py @@ -5,9 +5,10 @@ # the root directory of this source tree. from enum import Enum -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel +from llama_stack.schema_utils import json_schema_type + @json_schema_type class Job(BaseModel): diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py index b4bd1b0c6..d6c6c6919 100644 --- a/llama_stack/apis/common/training_types.py +++ b/llama_stack/apis/common/training_types.py @@ -7,9 +7,10 @@ from datetime import datetime from typing import Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel +from llama_stack.schema_utils import json_schema_type + @json_schema_type class PostTrainingMetric(BaseModel): diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py index fa9c5e92e..139ae8875 100644 --- a/llama_stack/apis/common/type_system.py +++ b/llama_stack/apis/common/type_system.py @@ -6,10 +6,11 @@ from typing import Literal, Union -from llama_models.schema_utils import json_schema_type, register_schema from pydantic import BaseModel, Field from typing_extensions import Annotated +from llama_stack.schema_utils import json_schema_type, register_schema + @json_schema_type class StringType(BaseModel): diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index 2ad7aab73..d85d22876 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -6,10 +6,10 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.datasets import Dataset +from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 5e2b38697..fe9d30e2a 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -6,12 +6,12 @@ from typing import Any, Dict, List, Literal, 
Optional, Protocol -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.schema_utils import json_schema_type, webmethod class CommonDatasetFields(BaseModel): diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index ccc395b80..6df93052c 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -6,7 +6,7 @@ from enum import Enum -from llama_models.schema_utils import json_schema_type +from llama_stack.schema_utils import json_schema_type @json_schema_type @@ -28,7 +28,7 @@ class Api(Enum): vector_dbs = "vector_dbs" datasets = "datasets" scoring_functions = "scoring_functions" - eval_tasks = "eval_tasks" + benchmarks = "benchmarks" tool_groups = "tool_groups" # built-in API diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index ae13a5bd9..e2ff4458e 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, Union -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -15,6 +14,7 @@ from llama_stack.apis.common.job_types import Job, JobStatus from llama_stack.apis.inference import SamplingParams, SystemMessage from llama_stack.apis.scoring import ScoringResult from llama_stack.apis.scoring_functions import ScoringFnParams +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @json_schema_type @@ -38,19 +38,9 @@ EvalCandidate = register_schema( @json_schema_type -class BenchmarkEvalTaskConfig(BaseModel): +class BenchmarkConfig(BaseModel): type: Literal["benchmark"] = "benchmark" eval_candidate: EvalCandidate - num_examples: Optional[int] = Field( - description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated", - default=None, - ) - - -@json_schema_type -class AppEvalTaskConfig(BaseModel): - type: Literal["app"] = "app" - eval_candidate: EvalCandidate scoring_params: Dict[str, ScoringFnParams] = Field( description="Map between scoring function id and parameters for each scoring function you want to run", default_factory=dict, @@ -62,12 +52,6 @@ class AppEvalTaskConfig(BaseModel): # we could optinally add any specific dataset config here -EvalTaskConfig = register_schema( - Annotated[Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")], - name="EvalTaskConfig", -) - - @json_schema_type class EvaluateResponse(BaseModel): generations: List[Dict[str, Any]] @@ -76,27 +60,52 @@ class EvaluateResponse(BaseModel): class Eval(Protocol): - @webmethod(route="/eval/tasks/{task_id}/jobs", method="POST") + @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST") async def run_eval( + self, + benchmark_id: str, + task_config: BenchmarkConfig, + ) -> Job: ... + + @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST") + async def evaluate_rows( + self, + benchmark_id: str, + input_rows: List[Dict[str, Any]], + scoring_functions: List[str], + task_config: BenchmarkConfig, + ) -> EvaluateResponse: ... 
+ + @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET") + async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: ... + + @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE") + async def job_cancel(self, benchmark_id: str, job_id: str) -> None: ... + + @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET") + async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ... + + @webmethod(route="/eval/tasks/{task_id}/jobs", method="POST") + async def DEPRECATED_run_eval( self, task_id: str, - task_config: EvalTaskConfig, + task_config: BenchmarkConfig, ) -> Job: ... @webmethod(route="/eval/tasks/{task_id}/evaluations", method="POST") - async def evaluate_rows( + async def DEPRECATED_evaluate_rows( self, task_id: str, input_rows: List[Dict[str, Any]], scoring_functions: List[str], - task_config: EvalTaskConfig, + task_config: BenchmarkConfig, ) -> EvaluateResponse: ... @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="GET") - async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ... + async def DEPRECATED_job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ... @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="DELETE") - async def job_cancel(self, task_id: str, job_id: str) -> None: ... + async def DEPRECATED_job_cancel(self, task_id: str, job_id: str) -> None: ... @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}/result", method="GET") - async def job_result(self, job_id: str, task_id: str) -> EvaluateResponse: ... + async def DEPRECATED_job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ... diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py deleted file mode 100644 index a0a533055..000000000 --- a/llama_stack/apis/eval_tasks/eval_tasks.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable - -from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel, Field - -from llama_stack.apis.resource import Resource, ResourceType - - -class CommonEvalTaskFields(BaseModel): - dataset_id: str - scoring_functions: List[str] - metadata: Dict[str, Any] = Field( - default_factory=dict, - description="Metadata for this evaluation task", - ) - - -@json_schema_type -class EvalTask(CommonEvalTaskFields, Resource): - type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value - - @property - def eval_task_id(self) -> str: - return self.identifier - - @property - def provider_eval_task_id(self) -> str: - return self.provider_resource_id - - -class EvalTaskInput(CommonEvalTaskFields, BaseModel): - eval_task_id: str - provider_id: Optional[str] = None - provider_eval_task_id: Optional[str] = None - - -class ListEvalTasksResponse(BaseModel): - data: List[EvalTask] - - -@runtime_checkable -class EvalTasks(Protocol): - @webmethod(route="/eval-tasks", method="GET") - async def list_eval_tasks(self) -> ListEvalTasksResponse: ... - - @webmethod(route="/eval-tasks/{eval_task_id}", method="GET") - async def get_eval_task( - self, - eval_task_id: str, - ) -> Optional[EvalTask]: ... 
- - @webmethod(route="/eval-tasks", method="POST") - async def register_eval_task( - self, - eval_task_id: str, - dataset_id: str, - scoring_functions: List[str], - provider_eval_task_id: Optional[str] = None, - provider_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> None: ... diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 9fccd3911..a3fb69477 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -17,7 +17,13 @@ from typing import ( runtime_checkable, ) -from llama_models.llama3.api.datatypes import ( +from pydantic import BaseModel, Field, field_validator +from typing_extensions import Annotated + +from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent +from llama_stack.apis.models import Model +from llama_stack.apis.telemetry.telemetry import MetricResponseMixin +from llama_stack.models.llama.datatypes import ( BuiltinTool, SamplingParams, StopReason, @@ -25,14 +31,8 @@ from llama_models.llama3.api.datatypes import ( ToolDefinition, ToolPromptFormat, ) -from llama_models.schema_utils import json_schema_type, register_schema, webmethod -from pydantic import BaseModel, Field, field_validator -from typing_extensions import Annotated - -from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent -from llama_stack.apis.models import Model -from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod class LogProbConfig(BaseModel): @@ -182,10 +182,12 @@ class ToolChoice(Enum): :cvar auto: The model may use tools if it determines that is appropriate. :cvar required: The model must use tools. + :cvar none: The model must not use tools. """ auto = "auto" required = "required" + none = "none" @json_schema_type @@ -326,7 +328,7 @@ class SystemMessageBehavior(Enum): class ToolConfig(BaseModel): """Configuration for tool use. - :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. + :param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto. :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. @@ -337,9 +339,16 @@ class ToolConfig(BaseModel): '{{function_definitions}}' to indicate where the function definitions should be inserted. 
""" - tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) + tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None) - system_message_behavior: SystemMessageBehavior = Field(default=SystemMessageBehavior.append) + system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append) + + def model_post_init(self, __context: Any) -> None: + if isinstance(self.tool_choice, str): + try: + self.tool_choice = ToolChoice[self.tool_choice] + except KeyError: + pass # This is an internally used class diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py index cd51469c1..4a647a2d9 100644 --- a/llama_stack/apis/inspect/inspect.py +++ b/llama_stack/apis/inspect/inspect.py @@ -6,9 +6,10 @@ from typing import List, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel +from llama_stack.schema_utils import json_schema_type, webmethod + @json_schema_type class ProviderInfo(BaseModel): diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 7e6d9854f..64b9510ea 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -7,11 +7,11 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod class CommonModelFields(BaseModel): diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 8cd2979a8..ed15c6de4 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -8,13 +8,13 @@ from datetime import datetime from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, Union -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.job_types import JobStatus from llama_stack.apis.common.training_types import Checkpoint +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @json_schema_type diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index 145113a5d..70ec63c55 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -15,7 +15,7 @@ class ResourceType(Enum): vector_db = "vector_db" dataset = "dataset" scoring_function = "scoring_function" - eval_task = "eval_task" + benchmark = "benchmark" tool = "tool" tool_group = "tool_group" diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 513733d1e..fd2f0292c 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -7,12 +7,12 @@ from enum import Enum from typing import Any, Dict, List, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_stack.apis.inference import Message from llama_stack.apis.shields import Shield from 
llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index 5bacaaf66..960149476 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -6,10 +6,10 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams +from llama_stack.schema_utils import json_schema_type, webmethod # mapping of metric to value ScoringResultRow = Dict[str, Any] diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index fece50fbd..52508d2ec 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -16,12 +16,12 @@ from typing import ( runtime_checkable, ) -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod # Perhaps more structure can be imposed on these functions. Maybe they could be associated diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index ae316ee53..ec1179ac4 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -6,11 +6,11 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod class CommonShieldFields(BaseModel): diff --git a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py index a61fb0cf2..7b41192af 100644 --- a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +++ b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py @@ -7,10 +7,10 @@ from enum import Enum from typing import Any, Dict, List, Optional, Protocol, Union -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.inference import Message +from llama_stack.schema_utils import json_schema_type, webmethod class FilteringFunction(Enum): diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 63ae1dc73..d010a7e3b 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -17,11 +17,12 @@ from typing import ( runtime_checkable, ) -from llama_models.llama3.api.datatypes import Primitive -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated +from llama_stack.models.llama.datatypes import Primitive +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod + # Add 
this constant near the top of the file, after the imports DEFAULT_TTL_DAYS = 7 diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 2e6b43eb8..cff8eeefe 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -7,12 +7,12 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated, Protocol, runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @json_schema_type diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 2a407ca00..b83be127f 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -7,13 +7,13 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod from .rag_tool import RAGToolRuntime diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index 1da2c128c..9a4aa322f 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -6,11 +6,11 @@ from typing import List, Literal, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 8feeaa6d4..2bbb3bce8 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -10,12 +10,12 @@ # the root directory of this source tree. 
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod class Chunk(BaseModel): diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index 3ea534277..af86f7243 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -16,8 +16,6 @@ from pathlib import Path from typing import Dict, List, Optional import httpx -from llama_models.datatypes import Model -from llama_models.sku_list import LlamaDownloadInfo from pydantic import BaseModel, ConfigDict from rich.console import Console from rich.progress import ( @@ -31,6 +29,8 @@ from rich.progress import ( from termcolor import cprint from llama_stack.cli.subcommand import Subcommand +from llama_stack.models.llama.datatypes import Model +from llama_stack.models.llama.sku_list import LlamaDownloadInfo class Download(Subcommand): @@ -56,7 +56,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--model-id", required=False, - help="See `llama model list` or `llama model list --show-all` for the list of available models", + help="See `llama model list` or `llama model list --show-all` for the list of available models. Specify multiple model IDs with commas, e.g. --model-id Llama3.2-1B,Llama3.2-3B", ) parser.add_argument( "--hf-token", @@ -83,8 +83,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None: type=str, required=False, default="*.safetensors", - help=""" -For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring + help="""For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring safetensors files to avoid downloading duplicate weights. 
""", ) @@ -454,7 +453,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser): # Handle comma-separated model IDs model_ids = [model_id.strip() for model_id in args.model_id.split(",")] - from llama_models.sku_list import llama_meta_net_info, resolve_model + from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model from .model.safety_models import ( prompt_guard_download_info, diff --git a/llama_stack/cli/model/describe.py b/llama_stack/cli/model/describe.py index a25513633..d8f4e035c 100644 --- a/llama_stack/cli/model/describe.py +++ b/llama_stack/cli/model/describe.py @@ -7,11 +7,11 @@ import argparse import json -from llama_models.sku_list import resolve_model from termcolor import colored from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.table import print_table +from llama_stack.models.llama.sku_list import resolve_model class ModelDescribe(Subcommand): @@ -34,6 +34,7 @@ class ModelDescribe(Subcommand): "--model-id", type=str, required=True, + help="See `llama model list` or `llama model list --show-all` for the list of available models", ) def _run_model_describe_cmd(self, args: argparse.Namespace) -> None: diff --git a/llama_stack/cli/model/list.py b/llama_stack/cli/model/list.py index 9b5ebb1a5..e6bf2216a 100644 --- a/llama_stack/cli/model/list.py +++ b/llama_stack/cli/model/list.py @@ -6,10 +6,9 @@ import argparse -from llama_models.sku_list import all_registered_models - from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.table import print_table +from llama_stack.models.llama.sku_list import all_registered_models class ModelList(Subcommand): @@ -37,8 +36,8 @@ class ModelList(Subcommand): from .safety_models import prompt_guard_model_sku headers = [ - "Model Descriptor", - "Model ID", + "Model Descriptor(ID)", + "Hugging Face Repo", "Context Length", ] diff --git a/llama_stack/cli/model/prompt_format.py b/llama_stack/cli/model/prompt_format.py index 2e1e1601e..ea9596ba5 100644 --- a/llama_stack/cli/model/prompt_format.py +++ b/llama_stack/cli/model/prompt_format.py @@ -8,9 +8,8 @@ import argparse import textwrap from io import StringIO -from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family - from llama_stack.cli.subcommand import Subcommand +from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family class ModelPromptFormat(Subcommand): diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py index 2321c4615..c81783f60 100644 --- a/llama_stack/cli/model/safety_models.py +++ b/llama_stack/cli/model/safety_models.py @@ -6,11 +6,11 @@ from typing import Any, Dict, Optional -from llama_models.datatypes import CheckpointQuantizationFormat -from llama_models.llama3.api.datatypes import SamplingParams -from llama_models.sku_list import LlamaDownloadInfo from pydantic import BaseModel, ConfigDict, Field +from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams +from llama_stack.models.llama.sku_list import LlamaDownloadInfo + class PromptGuardModel(BaseModel): """Make a 'fake' Model-like object for Prompt Guard. 
Eventually this will be removed.""" diff --git a/llama_stack/cli/model/verify_download.py b/llama_stack/cli/model/verify_download.py index b8e6bf173..e7159c0aa 100644 --- a/llama_stack/cli/model/verify_download.py +++ b/llama_stack/cli/model/verify_download.py @@ -15,7 +15,7 @@ class ModelVerifyDownload(Subcommand): self.parser = subparsers.add_parser( "verify-download", prog="llama model verify-download", - description="Verify the downloaded checkpoints' checksums", + description="Verify the downloaded checkpoints' checksums for models downloaded from Meta", formatter_class=argparse.RawTextHelpFormatter, ) diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index 729bd3ff1..7b17a960a 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -38,9 +38,8 @@ class StackBuild(Subcommand): self.parser.add_argument( "--list-templates", - type=bool, + action="store_true", default=False, - action=argparse.BooleanOptionalAction, help="Show the available templates for building a Llama Stack distribution", ) @@ -56,9 +55,8 @@ class StackBuild(Subcommand): "--image-name", type=str, help=textwrap.dedent( - """[for image-type=conda] Name of the conda environment to use for the build. If -not specified, currently active Conda environment will be used. If no Conda -environment is active, you must specify a name. + """[for image-type=conda|venv] Name of the conda or virtual environment to use for +the build. If not specified, currently active Conda environment will be used if found. """ ), default=None, diff --git a/llama_stack/cli/stack/configure.py b/llama_stack/cli/stack/configure.py index 56f4feceb..2bb3f7313 100644 --- a/llama_stack/cli/stack/configure.py +++ b/llama_stack/cli/stack/configure.py @@ -17,7 +17,7 @@ class StackConfigure(Subcommand): self.parser = subparsers.add_parser( "configure", prog="llama stack configure", - description="configure a llama stack distribution", + description="Configure a llama stack distribution", formatter_class=argparse.RawTextHelpFormatter, ) self._add_arguments() diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index c32e51fca..73536491b 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -19,7 +19,7 @@ class StackRun(Subcommand): self.parser = subparsers.add_parser( "run", prog="llama stack run", - description="""start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""", + description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""", formatter_class=argparse.RawTextHelpFormatter, ) self._add_arguments() diff --git a/llama_stack/cli/table.py b/llama_stack/cli/table.py index 847719f81..bf59e6103 100644 --- a/llama_stack/cli/table.py +++ b/llama_stack/cli/table.py @@ -4,75 +4,36 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import re -import textwrap from typing import Iterable -from termcolor import cprint - - -def strip_ansi_colors(text): - ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") - return ansi_escape.sub("", text) - - -def format_row(row, col_widths): - def wrap(text, width): - lines = [] - for line in text.split("\n"): - if line.strip() == "": - lines.append("") - else: - lines.extend(textwrap.wrap(line, width, break_long_words=False, replace_whitespace=False)) - return lines - - wrapped = [wrap(item, width) for item, width in zip(row, col_widths)] - max_lines = max(len(subrow) for subrow in wrapped) - - lines = [] - for i in range(max_lines): - line = [] - for cell_lines, width in zip(wrapped, col_widths): - value = cell_lines[i] if i < len(cell_lines) else "" - line.append(value + " " * (width - len(strip_ansi_colors(value)))) - lines.append("| " + (" | ".join(line)) + " |") - - return "\n".join(lines) +from rich.console import Console +from rich.table import Table def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()): - def itemlen(item): - return max([len(line) for line in strip_ansi_colors(item).split("\n")]) - + # Convert rows and handle None values rows = [[x or "" for x in row] for row in rows] + # Sort rows if sort_by is specified if sort_by: rows.sort(key=lambda x: tuple(x[i] for i in sort_by)) - if not headers: - col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)] - else: - col_widths = [ - max( - itemlen(header), - max(itemlen(item) for item in col), - ) - for header, col in zip(headers, zip(*rows)) - ] - col_widths = [min(w, 80) for w in col_widths] - - header_line = "+".join("-" * (width + 2) for width in col_widths) - header_line = f"+{header_line}+" + # Create Rich table + table = Table(show_lines=separate_rows) + # Add headers if provided if headers: - print(header_line) - cprint(format_row(headers, col_widths), "white", attrs=["bold"]) + for header in headers: + table.add_column(header, style="bold white") + else: + # Add unnamed columns based on first row + for _ in range(len(rows[0]) if rows else 0): + table.add_column() - print(header_line) + # Add rows for row in rows: - print(format_row(row, col_widths)) - if separate_rows: - print(header_line) + table.add_row(*row) - if not separate_rows: - print(header_line) + # Print table + console = Console() + console.print(table) diff --git a/llama_stack/cli/verify_download.py b/llama_stack/cli/verify_download.py index 47993c361..1229e8601 100644 --- a/llama_stack/cli/verify_download.py +++ b/llama_stack/cli/verify_download.py @@ -44,7 +44,7 @@ def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--model-id", required=True, - help="Model ID to verify", + help="Model ID to verify (only for models downloaded from Meta)", ) parser.set_defaults(func=partial(run_verify_cmd, parser=parser)) diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 9422c8457..511817de8 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -126,7 +126,6 @@ def build_image( args = [ script, str(image_name), - str(build_file_path), " ".join(normal_deps), ] diff --git a/llama_stack/distribution/build_venv.sh b/llama_stack/distribution/build_venv.sh index 3cb290bb7..0b0bffcfd 100755 --- a/llama_stack/distribution/build_venv.sh +++ b/llama_stack/distribution/build_venv.sh @@ -24,23 +24,21 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then fi if [ "$#" -lt 3 ]; then - echo "Usage: $0 
[]" >&2 + echo "Usage: $0 []" >&2 echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2 exit 1 fi -special_pip_deps="$4" +special_pip_deps="$3" set -euo pipefail build_name="$1" env_name="llamastack-$build_name" -build_file_path="$2" -pip_dependencies="$3" +pip_dependencies="$2" # Define color codes RED='\033[0;31m' -GREEN='\033[0;32m' NC='\033[0m' # No Color # this is set if we actually create a new conda in which case we need to clean up @@ -49,34 +47,63 @@ ENVNAME="" SCRIPT_DIR=$(dirname "$(readlink -f "$0")") source "$SCRIPT_DIR/common.sh" +# pre-run checks to make sure we can proceed with the installation +pre_run_checks() { + local env_name="$1" + + if ! is_command_available uv; then + echo "uv is not installed, trying to install it." + if ! is_command_available pip; then + echo "pip is not installed, cannot automatically install 'uv'." + echo "Follow this link to install it:" + echo "https://docs.astral.sh/uv/getting-started/installation/" + exit 1 + else + pip install uv + fi + fi + + # checking if an environment with the same name already exists + if [ -d "$env_name" ]; then + echo "Environment '$env_name' already exists, re-using it." + fi +} + run() { local env_name="$1" local pip_dependencies="$2" local special_pip_deps="$3" - pip install uv + echo "Using virtual environment $env_name" + uv venv "$env_name" + # shellcheck source=/dev/null + source "$env_name/bin/activate" if [ -n "$TEST_PYPI_VERSION" ]; then # these packages are damaged in test-pypi, so install them first uv pip install fastapi libcst + # shellcheck disable=SC2086 + # we are building a command line so word splitting is expected uv pip install --extra-index-url https://test.pypi.org/simple/ \ - llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION \ + llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \ $pip_dependencies if [ -n "$special_pip_deps" ]; then IFS='#' read -ra parts <<<"$special_pip_deps" for part in "${parts[@]}"; do echo "$part" + # shellcheck disable=SC2086 + # we are building a command line so word splitting is expected uv pip install $part done fi else - # Re-installing llama-stack in the new conda environment + # Re-installing llama-stack in the new virtual environment if [ -n "$LLAMA_STACK_DIR" ]; then if [ ! -d "$LLAMA_STACK_DIR" ]; then - printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2 + printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2 exit 1 fi - printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n" + printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR" uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR" else uv pip install --no-cache-dir llama-stack @@ -84,26 +111,31 @@ run() { if [ -n "$LLAMA_MODELS_DIR" ]; then if [ ! 
-d "$LLAMA_MODELS_DIR" ]; then - printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2 + printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_MODELS_DIR" >&2 exit 1 fi - printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n" + printf "Installing from LLAMA_MODELS_DIR: %s\n" "$LLAMA_MODELS_DIR" uv pip uninstall llama-models uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR" fi # Install pip dependencies printf "Installing pip dependencies\n" + # shellcheck disable=SC2086 + # we are building a command line so word splitting is expected uv pip install $pip_dependencies if [ -n "$special_pip_deps" ]; then IFS='#' read -ra parts <<<"$special_pip_deps" for part in "${parts[@]}"; do echo "$part" + # shellcheck disable=SC2086 + # we are building a command line so word splitting is expected uv pip install $part done fi fi } +pre_run_checks "$env_name" run "$env_name" "$pip_dependencies" "$special_pip_deps" diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index b1d174ede..1925b864f 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -186,33 +186,3 @@ def extract_async_iterator_type(type_hint): inner_args = get_args(arg) return inner_args[0] return None - - -async def example(model: str = None): - from llama_stack.apis.inference import Inference, UserMessage # noqa: F403 - from llama_stack.apis.inference.event_logger import EventLogger - - client_class = create_api_client_class(Inference) - client = client_class("http://localhost:5003") - - if not model: - model = "Llama3.2-3B-Instruct" - - message = UserMessage(content="hello world, write me a 2 sentence poem about the moon") - cprint(f"User>{message.content}", "green") - - stream = True - iterator = await client.chat_completion( - model=model, - messages=[message], - stream=stream, - ) - - async for log in EventLogger().log(iterator): - log.print() - - -if __name__ == "__main__": - import asyncio - - asyncio.run(example()) diff --git a/llama_stack/distribution/common.sh b/llama_stack/distribution/common.sh index 963eb395b..171023389 100755 --- a/llama_stack/distribution/common.sh +++ b/llama_stack/distribution/common.sh @@ -38,3 +38,8 @@ setup_cleanup_handlers() { conda deactivate } + +# check if a command is present +is_command_available() { + command -v "$1" &>/dev/null +} diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 97706f22a..f62996081 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -8,10 +8,10 @@ from typing import Annotated, Any, Dict, List, Optional, Union from pydantic import BaseModel, Field +from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Dataset, DatasetInput from llama_stack.apis.eval import Eval -from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput from llama_stack.apis.inference import Inference from llama_stack.apis.models import Model, ModelInput from llama_stack.apis.safety import Safety @@ -37,7 +37,7 @@ RoutableObject = Union[ VectorDB, Dataset, ScoringFn, - EvalTask, + Benchmark, Tool, ToolGroup, ] @@ -50,7 +50,7 @@ RoutableObjectWithProvider = Annotated[ VectorDB, Dataset, ScoringFn, - EvalTask, + Benchmark, Tool, ToolGroup, ], @@ -173,7 +173,7 @@ a default SQLite store will be used.""", vector_dbs: List[VectorDBInput] = 
Field(default_factory=list) datasets: List[DatasetInput] = Field(default_factory=list) scoring_fns: List[ScoringFnInput] = Field(default_factory=list) - eval_tasks: List[EvalTaskInput] = Field(default_factory=list) + benchmarks: List[BenchmarkInput] = Field(default_factory=list) tool_groups: List[ToolGroupInput] = Field(default_factory=list) server: ServerConfig = Field( diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 2dcf38463..384e2c3c8 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -44,7 +44,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: router_api=Api.scoring, ), AutoRoutedApiInfo( - routing_table_api=Api.eval_tasks, + routing_table_api=Api.benchmarks, router_api=Api.eval, ), AutoRoutedApiInfo( diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 55a15e5e9..639e5ee73 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -13,7 +13,7 @@ import re from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path -from typing import Any, Optional, TypeVar, get_args, get_origin +from typing import Any, Optional, TypeVar, Union, get_args, get_origin import httpx import yaml @@ -47,6 +47,8 @@ from llama_stack.providers.utils.telemetry.tracing import ( start_trace, ) +logger = logging.getLogger(__name__) + T = TypeVar("T") @@ -81,12 +83,13 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any: return value origin = get_origin(annotation) + if origin is list: item_type = get_args(annotation)[0] try: return [convert_to_pydantic(item_type, item) for item in value] except Exception: - print(f"Error converting list {value}") + logger.error(f"Error converting list {value} into {item_type}") return value elif origin is dict: @@ -94,17 +97,25 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any: try: return {k: convert_to_pydantic(val_type, v) for k, v in value.items()} except Exception: - print(f"Error converting dict {value}") + logger.error(f"Error converting dict {value} into {val_type}") return value try: # Handle Pydantic models and discriminated unions return TypeAdapter(annotation).validate_python(value) + except Exception as e: - cprint( - f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}", - "yellow", - ) + # TODO: this is workardound for having Union[str, AgentToolGroup] in API schema. + # We should get rid of any non-discriminated unions in the API schema. 
+ if origin is Union: + for union_type in get_args(annotation): + try: + return convert_to_pydantic(union_type, value) + except Exception: + continue + logger.warning( + f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}", + ) return value @@ -142,7 +153,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): for handler in root_logger.handlers[:]: root_logger.removeHandler(handler) - print(f"Removed handler {handler.__class__.__name__} from root logger") + logger.info(f"Removed handler {handler.__class__.__name__} from root logger") def request(self, *args, **kwargs): if kwargs.get("stream"): @@ -231,7 +242,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): def _convert_path_to_regex(path: str) -> str: # Convert {param} to named capture groups - pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path) + # handle {param:path} as well which allows for forward slashes in the param value + pattern = re.sub( + r"{(\w+)(?::path)?}", + lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})", + path, + ) + return f"^{pattern}$" for api, api_endpoints in endpoints.items(): @@ -415,4 +432,5 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if param_name in body: value = body.get(param_name) converted_body[param_name] = convert_to_pydantic(param.annotation, value) + return converted_body diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 353c2971b..0bc2e774c 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -9,10 +9,10 @@ import logging from typing import Any, Dict, List, Set from llama_stack.apis.agents import Agents +from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval import Eval -from llama_stack.apis.eval_tasks import EvalTasks from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect from llama_stack.apis.models import Models @@ -37,8 +37,8 @@ from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.providers.datatypes import ( Api, + BenchmarksProtocolPrivate, DatasetsProtocolPrivate, - EvalTasksProtocolPrivate, InlineProviderSpec, ModelsProtocolPrivate, ProviderSpec, @@ -73,7 +73,7 @@ def api_protocol_map() -> Dict[Api, Any]: Api.scoring: Scoring, Api.scoring_functions: ScoringFunctions, Api.eval: Eval, - Api.eval_tasks: EvalTasks, + Api.benchmarks: Benchmarks, Api.post_training: PostTraining, Api.tool_groups: ToolGroups, Api.tool_runtime: ToolRuntime, @@ -92,7 +92,7 @@ def additional_protocols_map() -> Dict[Api, Any]: ScoringFunctions, Api.scoring_functions, ), - Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks), + Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks), } diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 18197ca7f..a54f57fb3 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -11,8 +11,8 @@ from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable from .routing_tables import ( + BenchmarksRoutingTable, DatasetsRoutingTable, - EvalTasksRoutingTable, ModelsRoutingTable, ScoringFunctionsRoutingTable, ShieldsRoutingTable, @@ -33,7 +33,7 @@ 
async def get_routing_table_impl( "shields": ShieldsRoutingTable, "datasets": DatasetsRoutingTable, "scoring_functions": ScoringFunctionsRoutingTable, - "eval_tasks": EvalTasksRoutingTable, + "benchmarks": BenchmarksRoutingTable, "tool_groups": ToolGroupsRoutingTable, } diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index e716e44b0..9d12c8a40 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -9,9 +9,8 @@ from typing import Any, AsyncGenerator, Dict, List, Optional from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.eval import ( - AppEvalTaskConfig, + BenchmarkConfig, Eval, - EvalTaskConfig, EvaluateResponse, Job, JobStatus, @@ -129,7 +128,7 @@ class InferenceRouter(Inference): sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_choice: Optional[ToolChoice] = None, tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, @@ -141,20 +140,36 @@ class InferenceRouter(Inference): if model.model_type == ModelType.embedding: raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") if tool_config: - if tool_choice != tool_config.tool_choice: + if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") - if tool_prompt_format != tool_config.tool_prompt_format: + if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") else: - tool_config = ToolConfig( - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, - ) + params = {} + if tool_choice: + params["tool_choice"] = tool_choice + if tool_prompt_format: + params["tool_prompt_format"] = tool_prompt_format + tool_config = ToolConfig(**params) + + tools = tools or [] + if tool_config.tool_choice == ToolChoice.none: + tools = [] + elif tool_config.tool_choice == ToolChoice.auto: + pass + elif tool_config.tool_choice == ToolChoice.required: + pass + else: + # verify tool_choice is one of the tools + tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] + if tool_config.tool_choice not in tool_names: + raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") + params = dict( model_id=model_id, messages=messages, sampling_params=sampling_params, - tools=tools or [], + tools=tools, tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, response_format=response_format, @@ -347,23 +362,23 @@ class EvalRouter(Eval): async def run_eval( self, - task_id: str, - task_config: AppEvalTaskConfig, + benchmark_id: str, + task_config: BenchmarkConfig, ) -> Job: - return await self.routing_table.get_provider_impl(task_id).run_eval( - task_id=task_id, + return await self.routing_table.get_provider_impl(benchmark_id).run_eval( + benchmark_id=benchmark_id, task_config=task_config, ) async def evaluate_rows( self, - task_id: str, + benchmark_id: str, input_rows: List[Dict[str, Any]], scoring_functions: List[str], - task_config: EvalTaskConfig, + task_config: BenchmarkConfig, ) -> 
EvaluateResponse: - return await self.routing_table.get_provider_impl(task_id).evaluate_rows( - task_id=task_id, + return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( + benchmark_id=benchmark_id, input_rows=input_rows, scoring_functions=scoring_functions, task_config=task_config, @@ -371,30 +386,72 @@ class EvalRouter(Eval): async def job_status( self, - task_id: str, + benchmark_id: str, job_id: str, ) -> Optional[JobStatus]: - return await self.routing_table.get_provider_impl(task_id).job_status(task_id, job_id) + return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) async def job_cancel( self, - task_id: str, + benchmark_id: str, job_id: str, ) -> None: - await self.routing_table.get_provider_impl(task_id).job_cancel( - task_id, + await self.routing_table.get_provider_impl(benchmark_id).job_cancel( + benchmark_id, job_id, ) async def job_result( + self, + benchmark_id: str, + job_id: str, + ) -> EvaluateResponse: + return await self.routing_table.get_provider_impl(benchmark_id).job_result( + benchmark_id, + job_id, + ) + + async def DEPRECATED_run_eval( + self, + task_id: str, + task_config: BenchmarkConfig, + ) -> Job: + return await self.run_eval(benchmark_id=task_id, task_config=task_config) + + async def DEPRECATED_evaluate_rows( + self, + task_id: str, + input_rows: List[Dict[str, Any]], + scoring_functions: List[str], + task_config: BenchmarkConfig, + ) -> EvaluateResponse: + return await self.evaluate_rows( + benchmark_id=task_id, + input_rows=input_rows, + scoring_functions=scoring_functions, + task_config=task_config, + ) + + async def DEPRECATED_job_status( + self, + task_id: str, + job_id: str, + ) -> Optional[JobStatus]: + return await self.job_status(benchmark_id=task_id, job_id=job_id) + + async def DEPRECATED_job_cancel( + self, + task_id: str, + job_id: str, + ) -> None: + return await self.job_cancel(benchmark_id=task_id, job_id=job_id) + + async def DEPRECATED_job_result( self, task_id: str, job_id: str, ) -> EvaluateResponse: - return await self.routing_table.get_provider_impl(task_id).job_result( - task_id, - job_id, - ) + return await self.job_result(benchmark_id=task_id, job_id=job_id) class ToolRuntimeRouter(ToolRuntime): diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 009775ca5..2cddc3970 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -4,14 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
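# --- Editor's sketch (illustrative only, not part of the patch) ---
# The EvalRouter hunk above renames `task_id`/`EvalTaskConfig` to
# `benchmark_id`/`BenchmarkConfig`, and keeps the old signatures alive as thin
# DEPRECATED_* wrappers that simply forward to the new methods. A minimal
# migration sketch follows; `eval_router`, the benchmark id, and `job.job_id`
# are assumptions made for illustration, not guaranteed by this patch.

from typing import Any, Dict


async def run_benchmark_sketch(eval_router: Any, benchmark_id: str, benchmark_config: Any) -> Dict[str, Any]:
    # New-style call: pass benchmark_id (the deprecated path still accepts
    # task_id via DEPRECATED_run_eval, which forwards here).
    job = await eval_router.run_eval(benchmark_id=benchmark_id, task_config=benchmark_config)
    # Poll and fetch results with the same benchmark_id.
    status = await eval_router.job_status(benchmark_id=benchmark_id, job_id=job.job_id)
    result = await eval_router.job_result(benchmark_id=benchmark_id, job_id=job.job_id)
    return {"status": status, "result": result}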
+import logging from typing import Any, Dict, List, Optional from pydantic import TypeAdapter +from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse -from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ( @@ -38,6 +39,8 @@ from llama_stack.distribution.datatypes import ( from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable +logger = logging.getLogger(__name__) + def get_impl_api(p: Any) -> Api: return p.__provider_spec__.api @@ -60,7 +63,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable elif api == Api.scoring: return await p.register_scoring_function(obj) elif api == Api.eval: - return await p.register_eval_task(obj) + return await p.register_benchmark(obj) elif api == Api.tool_runtime: return await p.register_tool(obj) else: @@ -121,7 +124,7 @@ class CommonRoutingTableImpl(RoutingTable): scoring_functions = await p.list_scoring_functions() await add_objects(scoring_functions, pid, ScoringFn) elif api == Api.eval: - p.eval_task_store = self + p.benchmark_store = self elif api == Api.tool_runtime: p.tool_store = self @@ -141,8 +144,8 @@ class CommonRoutingTableImpl(RoutingTable): return ("DatasetIO", "dataset") elif isinstance(self, ScoringFunctionsRoutingTable): return ("Scoring", "scoring_function") - elif isinstance(self, EvalTasksRoutingTable): - return ("Eval", "eval_task") + elif isinstance(self, BenchmarksRoutingTable): + return ("Eval", "benchmark") elif isinstance(self, ToolGroupsRoutingTable): return ("Tools", "tool") else: @@ -428,20 +431,20 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): await self.register_object(scoring_fn) -class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks): - async def list_eval_tasks(self) -> ListEvalTasksResponse: - return ListEvalTasksResponse(data=await self.get_all_with_type("eval_task")) +class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): + async def list_benchmarks(self) -> ListBenchmarksResponse: + return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) - async def get_eval_task(self, eval_task_id: str) -> Optional[EvalTask]: - return await self.get_object_by_identifier("eval_task", eval_task_id) + async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]: + return await self.get_object_by_identifier("benchmark", benchmark_id) - async def register_eval_task( + async def register_benchmark( self, - eval_task_id: str, + benchmark_id: str, dataset_id: str, scoring_functions: List[str], metadata: Optional[Dict[str, Any]] = None, - provider_eval_task_id: Optional[str] = None, + provider_benchmark_id: Optional[str] = None, provider_id: Optional[str] = None, ) -> None: if metadata is None: @@ -453,17 +456,46 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks): raise ValueError( "No provider specified and multiple providers available. Please specify a provider_id." 
) - if provider_eval_task_id is None: - provider_eval_task_id = eval_task_id - eval_task = EvalTask( - identifier=eval_task_id, + if provider_benchmark_id is None: + provider_benchmark_id = benchmark_id + benchmark = Benchmark( + identifier=benchmark_id, dataset_id=dataset_id, scoring_functions=scoring_functions, metadata=metadata, provider_id=provider_id, - provider_resource_id=provider_eval_task_id, + provider_resource_id=provider_benchmark_id, + ) + await self.register_object(benchmark) + + async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse: + logger.warning("DEPRECATED: Use /eval/benchmarks instead") + return await self.list_benchmarks() + + async def DEPRECATED_get_eval_task( + self, + eval_task_id: str, + ) -> Optional[Benchmark]: + logger.warning("DEPRECATED: Use /eval/benchmarks instead") + return await self.get_benchmark(eval_task_id) + + async def DEPRECATED_register_eval_task( + self, + eval_task_id: str, + dataset_id: str, + scoring_functions: List[str], + provider_benchmark_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + logger.warning("DEPRECATED: Use /eval/benchmarks instead") + return await self.register_benchmark( + benchmark_id=eval_task_id, + dataset_id=dataset_id, + scoring_functions=scoring_functions, + metadata=metadata, + provider_benchmark_id=provider_benchmark_id, ) - await self.register_object(eval_task) class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 2baad8ac4..9335dc3a9 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -15,10 +15,10 @@ from termcolor import colored from llama_stack.apis.agents import Agents from llama_stack.apis.batch_inference import BatchInference +from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval import Eval -from llama_stack.apis.eval_tasks import EvalTasks from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect from llama_stack.apis.models import Models @@ -53,7 +53,7 @@ class LlamaStack( PostTraining, VectorIO, Eval, - EvalTasks, + Benchmarks, Scoring, ScoringFunctions, DatasetIO, @@ -78,7 +78,7 @@ RESOURCES = [ "register_scoring_function", "list_scoring_functions", ), - ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"), + ("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"), ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"), ] diff --git a/llama_stack/distribution/ui/README.md b/llama_stack/distribution/ui/README.md index c0a2597af..8fceb5c63 100644 --- a/llama_stack/distribution/ui/README.md +++ b/llama_stack/distribution/ui/README.md @@ -26,7 +26,7 @@ $ llama-stack-client datasets register \ ``` ```bash -$ llama-stack-client eval_tasks register \ +$ llama-stack-client benchmarks register \ --eval-task-id meta-reference-mmlu \ --provider-id meta-reference \ --dataset-id mmlu \ diff --git a/llama_stack/distribution/ui/page/distribution/eval_tasks.py b/llama_stack/distribution/ui/page/distribution/eval_tasks.py index f58969663..1428ae9ab 100644 --- a/llama_stack/distribution/ui/page/distribution/eval_tasks.py +++ b/llama_stack/distribution/ui/page/distribution/eval_tasks.py @@ -8,12 +8,12 @@ import streamlit as st from modules.api import llama_stack_api -def eval_tasks(): - # 
Eval Tasks Section - st.header("Eval Tasks") +def benchmarks(): + # Benchmarks Section + st.header("Benchmarks") - eval_tasks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()} + benchmarks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.benchmarks.list()} - if len(eval_tasks_info) > 0: - selected_eval_task = st.selectbox("Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect") - st.json(eval_tasks_info[selected_eval_task], expanded=True) + if len(benchmarks_info) > 0: + selected_benchmark = st.selectbox("Select an eval task", list(benchmarks_info.keys()), key="benchmark_inspect") + st.json(benchmarks_info[selected_benchmark], expanded=True) diff --git a/llama_stack/distribution/ui/page/distribution/resources.py b/llama_stack/distribution/ui/page/distribution/resources.py index 94b840bcb..684270d4d 100644 --- a/llama_stack/distribution/ui/page/distribution/resources.py +++ b/llama_stack/distribution/ui/page/distribution/resources.py @@ -4,8 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from page.distribution.benchmarks import benchmarks from page.distribution.datasets import datasets -from page.distribution.eval_tasks import eval_tasks from page.distribution.models import models from page.distribution.scoring_functions import scoring_functions from page.distribution.shields import shields @@ -20,7 +20,7 @@ def resources_page(): "Shields", "Scoring Functions", "Datasets", - "Eval Tasks", + "Benchmarks", ] icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"] selected_resource = option_menu( @@ -34,8 +34,8 @@ def resources_page(): }, }, ) - if selected_resource == "Eval Tasks": - eval_tasks() + if selected_resource == "Benchmarks": + benchmarks() elif selected_resource == "Vector Databases": vector_dbs() elif selected_resource == "Datasets": diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py index 112d9cff0..f1cae714a 100644 --- a/llama_stack/distribution/ui/page/evaluations/native_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -11,28 +11,28 @@ import streamlit as st from modules.api import llama_stack_api -def select_eval_task_1(): - # Select Eval Tasks +def select_benchmark_1(): + # Select Benchmarks st.subheader("1. Choose An Eval Task") - eval_tasks = llama_stack_api.client.eval_tasks.list() - eval_tasks = {et.identifier: et for et in eval_tasks} - eval_tasks_names = list(eval_tasks.keys()) - selected_eval_task = st.selectbox( + benchmarks = llama_stack_api.client.benchmarks.list() + benchmarks = {et.identifier: et for et in benchmarks} + benchmarks_names = list(benchmarks.keys()) + selected_benchmark = st.selectbox( "Choose an eval task.", - options=eval_tasks_names, + options=benchmarks_names, help="Choose an eval task. 
Each eval task is parameterized by a dataset, and list of scoring functions.", ) with st.expander("View Eval Task"): - st.json(eval_tasks[selected_eval_task], expanded=True) + st.json(benchmarks[selected_benchmark], expanded=True) - st.session_state["selected_eval_task"] = selected_eval_task - st.session_state["eval_tasks"] = eval_tasks + st.session_state["selected_benchmark"] = selected_benchmark + st.session_state["benchmarks"] = benchmarks if st.button("Confirm", key="confirm_1"): - st.session_state["selected_eval_task_1_next"] = True + st.session_state["selected_benchmark_1_next"] = True def define_eval_candidate_2(): - if not st.session_state.get("selected_eval_task_1_next", None): + if not st.session_state.get("selected_benchmark_1_next", None): return st.subheader("2. Define Eval Candidate") @@ -161,11 +161,11 @@ def run_evaluation_3(): Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button. """ ) - selected_eval_task = st.session_state["selected_eval_task"] - eval_tasks = st.session_state["eval_tasks"] + selected_benchmark = st.session_state["selected_benchmark"] + benchmarks = st.session_state["benchmarks"] eval_candidate = st.session_state["eval_candidate"] - dataset_id = eval_tasks[selected_eval_task].dataset_id + dataset_id = benchmarks[selected_benchmark].dataset_id rows = llama_stack_api.client.datasetio.get_rows_paginated( dataset_id=dataset_id, rows_in_page=-1, @@ -180,16 +180,16 @@ def run_evaluation_3(): help="Number of examples from the dataset to evaluate. ", ) - eval_task_config = { + benchmark_config = { "type": "benchmark", "eval_candidate": eval_candidate, "scoring_params": {}, } with st.expander("View Evaluation Task", expanded=True): - st.json(eval_tasks[selected_eval_task], expanded=True) + st.json(benchmarks[selected_benchmark], expanded=True) with st.expander("View Evaluation Task Configuration", expanded=True): - st.json(eval_task_config, expanded=True) + st.json(benchmark_config, expanded=True) # Add run button and handle evaluation if st.button("Run Evaluation"): @@ -209,10 +209,10 @@ def run_evaluation_3(): progress_bar.progress(progress, text=progress_text) # Run evaluation for current row eval_res = llama_stack_api.client.eval.evaluate_rows( - task_id=selected_eval_task, + benchmark_id=selected_benchmark, input_rows=[r], - scoring_functions=eval_tasks[selected_eval_task].scoring_functions, - task_config=eval_task_config, + scoring_functions=benchmarks[selected_benchmark].scoring_functions, + task_config=benchmark_config, ) for k in r.keys(): @@ -225,7 +225,7 @@ def run_evaluation_3(): output_res[k] = [] output_res[k].append(eval_res.generations[0][k]) - for scoring_fn in eval_tasks[selected_eval_task].scoring_functions: + for scoring_fn in benchmarks[selected_benchmark].scoring_functions: if scoring_fn not in output_res: output_res[scoring_fn] = [] output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0]) @@ -245,7 +245,7 @@ def native_evaluation_page(): st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙") st.title("📊 Evaluations (Generation + Scoring)") - select_eval_task_1() + select_benchmark_1() define_eval_candidate_2() run_evaluation_3() diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py new file mode 100644 index 000000000..a5dc9ac4a --- /dev/null +++ b/llama_stack/models/llama/datatypes.py @@ -0,0 +1,277 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from enum import Enum +from typing import Any, Dict, Literal, Optional, Union + +# import all for backwards compatibility +from llama_models.datatypes import * # noqa: F403 +from pydantic import BaseModel, ConfigDict, Field, field_validator +from typing_extensions import Annotated + +from llama_stack.schema_utils import json_schema_type, register_schema + +register_schema(ToolCall) + + +@json_schema_type +class ToolParamDefinition(BaseModel): + param_type: str + description: Optional[str] = None + required: Optional[bool] = True + default: Optional[Any] = None + + +@json_schema_type +class ToolDefinition(BaseModel): + tool_name: Union[BuiltinTool, str] + description: Optional[str] = None + parameters: Optional[Dict[str, ToolParamDefinition]] = None + + @field_validator("tool_name", mode="before") + @classmethod + def validate_field(cls, v): + if isinstance(v, str): + try: + return BuiltinTool(v) + except ValueError: + return v + return v + + +@json_schema_type +class GreedySamplingStrategy(BaseModel): + type: Literal["greedy"] = "greedy" + + +@json_schema_type +class TopPSamplingStrategy(BaseModel): + type: Literal["top_p"] = "top_p" + temperature: Optional[float] = Field(..., gt=0.0) + top_p: Optional[float] = 0.95 + + +@json_schema_type +class TopKSamplingStrategy(BaseModel): + type: Literal["top_k"] = "top_k" + top_k: int = Field(..., ge=1) + + +SamplingStrategy = register_schema( + Annotated[ + Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], + Field(discriminator="type"), + ], + name="SamplingStrategy", +) + + +@json_schema_type +class SamplingParams(BaseModel): + strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy) + + max_tokens: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 + + +class CheckpointQuantizationFormat(Enum): + # default format + bf16 = "bf16" + + # used for enabling fp8_rowwise inference, some weights are bf16 + fp8_mixed = "fp8-mixed" + + int8 = "int8" + + int4 = "int4" + + +class ModelFamily(Enum): + llama2 = "llama2" + llama3 = "llama3" + llama3_1 = "llama3_1" + llama3_2 = "llama3_2" + llama3_3 = "llama3_3" + safety = "safety" + + +class CoreModelId(Enum): + """Each of these models is a unique "SKU". 
These root models can be served in various garbs (especially by quantizing them)""" + + # Llama 2 family + llama2_7b = "Llama-2-7b" + llama2_13b = "Llama-2-13b" + llama2_70b = "Llama-2-70b" + llama2_7b_chat = "Llama-2-7b-chat" + llama2_13b_chat = "Llama-2-13b-chat" + llama2_70b_chat = "Llama-2-70b-chat" + + # Llama 3 family + llama3_8b = "Llama-3-8B" + llama3_70b = "Llama-3-70B" + llama3_8b_instruct = "Llama-3-8B-Instruct" + llama3_70b_instruct = "Llama-3-70B-Instruct" + + # Llama 3.1 family + llama3_1_8b = "Llama3.1-8B" + llama3_1_70b = "Llama3.1-70B" + llama3_1_405b = "Llama3.1-405B" + llama3_1_8b_instruct = "Llama3.1-8B-Instruct" + llama3_1_70b_instruct = "Llama3.1-70B-Instruct" + llama3_1_405b_instruct = "Llama3.1-405B-Instruct" + + # Llama 3.2 family + llama3_2_1b = "Llama3.2-1B" + llama3_2_3b = "Llama3.2-3B" + llama3_2_1b_instruct = "Llama3.2-1B-Instruct" + llama3_2_3b_instruct = "Llama3.2-3B-Instruct" + llama3_2_11b_vision = "Llama3.2-11B-Vision" + llama3_2_90b_vision = "Llama3.2-90B-Vision" + llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct" + llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct" + + # Llama 3.3 family + llama3_3_70b_instruct = "Llama3.3-70B-Instruct" + + # Safety models + llama_guard_3_8b = "Llama-Guard-3-8B" + llama_guard_2_8b = "Llama-Guard-2-8B" + llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision" + llama_guard_3_1b = "Llama-Guard-3-1B" + + +def is_multimodal(model_id) -> bool: + if model_id in [ + CoreModelId.llama3_2_11b_vision, + CoreModelId.llama3_2_90b_vision, + CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct, + ]: + return True + else: + return False + + +def model_family(model_id) -> ModelFamily: + if model_id in [ + CoreModelId.llama2_7b, + CoreModelId.llama2_13b, + CoreModelId.llama2_70b, + CoreModelId.llama2_7b_chat, + CoreModelId.llama2_13b_chat, + CoreModelId.llama2_70b_chat, + ]: + return ModelFamily.llama2 + elif model_id in [ + CoreModelId.llama3_8b, + CoreModelId.llama3_70b, + CoreModelId.llama3_8b_instruct, + CoreModelId.llama3_70b_instruct, + ]: + return ModelFamily.llama3 + elif model_id in [ + CoreModelId.llama3_1_8b, + CoreModelId.llama3_1_70b, + CoreModelId.llama3_1_405b, + CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_405b_instruct, + ]: + return ModelFamily.llama3_1 + elif model_id in [ + CoreModelId.llama3_2_1b, + CoreModelId.llama3_2_3b, + CoreModelId.llama3_2_1b_instruct, + CoreModelId.llama3_2_3b_instruct, + CoreModelId.llama3_2_11b_vision, + CoreModelId.llama3_2_90b_vision, + CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct, + ]: + return ModelFamily.llama3_2 + elif model_id in [ + CoreModelId.llama3_3_70b_instruct, + ]: + return ModelFamily.llama3_3 + elif model_id in [ + CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_2_8b, + CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_1b, + ]: + return ModelFamily.safety + else: + raise ValueError(f"Unknown model family for {model_id}") + + +class Model(BaseModel): + core_model_id: CoreModelId + description: str + huggingface_repo: Optional[str] = None + recommended_sampling_params: Optional[SamplingParams] = None + arch_args: Dict[str, Any] + variant: str = "" + + quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 + pth_file_count: int + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + + # silence pydantic until we remove the `model_` fields + model_config = 
ConfigDict(protected_namespaces=())
+
+    @property
+    def model_family(self) -> ModelFamily:
+        return model_family(self.core_model_id)
+
+    # The SKU is uniquely identified by (model_id, variant) combo
+    def descriptor(self, shorten_default_variant: bool = True) -> str:
+        if not self.variant:
+            return self.core_model_id.value
+        return f"{self.core_model_id.value}:{self.variant}"
+
+    @property
+    def is_instruct_model(self) -> bool:
+        # the instruct SKUs all carry "instruct" in their CoreModelId name
+        return "instruct" in self.core_model_id.name
+
+    # Featured models are shown in the non-exhaustive model list
+    @property
+    def is_featured(self) -> bool:
+        return self.model_family in [
+            ModelFamily.llama3_1,
+            ModelFamily.llama3_2,
+            ModelFamily.llama3_3,
+            ModelFamily.safety,
+        ]
+
+    @property
+    def max_seq_length(self) -> int:
+        if self.model_family == ModelFamily.llama2:
+            return 4096
+        elif self.core_model_id == CoreModelId.llama_guard_2_8b:
+            return 4096
+        elif self.model_family == ModelFamily.llama3:
+            return 8192
+        elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
+            return 131072
+        elif self.model_family == ModelFamily.llama3_2:
+            if self.quantization_format == CheckpointQuantizationFormat.int4:
+                return 8192
+            return 131072
+        elif self.core_model_id in [
+            CoreModelId.llama_guard_3_8b,
+            CoreModelId.llama_guard_3_11b_vision,
+            CoreModelId.llama_guard_3_1b,
+        ]:
+            return 131072
+        else:
+            raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
diff --git a/llama_stack/models/llama/llama3/dog.jpg b/llama_stack/models/llama/llama3/dog.jpg
new file mode 100644
index 000000000..f9a3a8057
Binary files /dev/null and b/llama_stack/models/llama/llama3/dog.jpg differ
diff --git a/llama_stack/models/llama/llama3/interface.py b/llama_stack/models/llama/llama3/interface.py
new file mode 100644
index 000000000..bc42228a5
--- /dev/null
+++ b/llama_stack/models/llama/llama3/interface.py
@@ -0,0 +1,257 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
+
+from pathlib import Path
+from typing import List, Optional
+
+from llama_models.datatypes import (
+    BuiltinTool,
+    RawMessage,
+    StopReason,
+    ToolCall,
+    ToolPromptFormat,
+)
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
+from termcolor import colored
+
+from llama_stack.models.llama.datatypes import ToolDefinition
+
+from .
import template_data +from .prompt_templates import ( + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + SystemDefaultGenerator, + ToolResponseGenerator, +) + +THIS_DIR = Path(__file__).parent + + +class Template: + def __init__( + self, + role, + template_name, + data_provider=None, + notes=None, + ): + self.role = role + self.template_name = template_name + self.data_provider = data_provider or "" + self._notes = notes or "" + + @property + def notes(self): + default = "↵ represents newline" + notes = default + if self._notes: + notes += "\n" + notes += self._notes + return notes + + +TEMPLATES = [ + Template( + "user", + "user-default", + "user_default", + ), + Template( + "user", + "user-images", + "user_images", + ), + Template("user", "user-interleaved-images", "user_interleaved_images"), + Template( + "assistant", + "assistant-builtin-tool-call", + "assistant_builtin_tool_call", + "Notice <|python_tag|>", + ), + Template( + "assistant", + "assistant-custom-tool-call", + "assistant_custom_tool_call", + "Notice format", + ), + Template( + "assistant", + "assistant-default", + "assistant_default", + ), + Template( + "system", + "system-builtin-and-custom-tools", + "system_message_builtin_and_custom_tools", + ), + Template( + "system", + "system-builtin-tools-only", + "system_message_builtin_tools_only", + ), + Template( + "system", + "system-custom-tools-only", + "system_message_custom_tools_only", + ), + Template( + "system", + "system-default", + "system_default", + ), + Template( + "tool", + "tool-success", + "tool_success", + "Note ipython header and [stdout]", + ), + Template( + "tool", + "tool-failure", + "tool_failure", + "Note ipython header and [stderr]", + ), +] + + +class LLama31Interface: + def __init__(self, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json): + self.tokenizer = Tokenizer.get_instance() + self.formatter = ChatFormat(self.tokenizer) + self.tool_prompt_format = tool_prompt_format + + def get_tokens(self, messages: List[RawMessage]) -> List[int]: + model_input = self.formatter.encode_dialog_prompt( + messages, + self.tool_prompt_format, + ) + return model_input.tokens + + def tool_response_messages(self, *args, **kwargs): + template = ToolResponseGenerator().gen(*args, **kwargs) + return [ + RawMessage( + role="tool", + content=template.render(), + ) + ] + + def system_messages( + self, + builtin_tools: List[BuiltinTool], + custom_tools: List[ToolDefinition], + instruction: Optional[str] = None, + ) -> List[RawMessage]: + messages = [] + + default_gen = SystemDefaultGenerator() + default_template = default_gen.gen() + + sys_content = "" + + tool_template = None + if builtin_tools or custom_tools: + tool_gen = BuiltinToolGenerator() + tool_template = tool_gen.gen(builtin_tools + custom_tools) + + sys_content += tool_template.render() + sys_content += "\n" + + sys_content += default_template.render() + + if instruction: + sys_content += "\n\n" + sys_content += instruction + + sys_content += "\n" + messages.append(RawMessage(role="system", content=sys_content)) + + if custom_tools: + if self.tool_prompt_format == ToolPromptFormat.json: + tool_gen = JsonCustomToolGenerator() + elif self.tool_prompt_format == ToolPromptFormat.function_tag: + tool_gen = FunctionTagCustomToolGenerator() + else: + raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}") + + custom_template = tool_gen.gen(custom_tools) + messages.append(RawMessage(role="user", content=custom_template.render())) + + return messages + + 
def assistant_response_messages( + self, + content: str, + stop_reason: StopReason, + tool_call: Optional[ToolCall] = None, + ) -> List[RawMessage]: + tool_calls = [] + if tool_call: + tool_calls.append(tool_call) + return [ + RawMessage( + role="assistant", + content=content, + tool_calls=tool_calls, + stop_reason=stop_reason, + ) + ] + + def user_message(self, content: str) -> List[RawMessage]: + return [RawMessage(role="user", content=content)] + + def display_message_as_tokens(self, message: RawMessage) -> None: + """Util to print tokenized string to shell""" + tokens = self.formatter.encode_message(message, self.tool_prompt_format) + on_colors = [ + "on_red", + "on_green", + "on_yellow", + "on_blue", + "on_magenta", + "on_cyan", + ] + for i, t in enumerate(tokens): + on_col = on_colors[i % len(on_colors)] + print(colored(self.tokenizer.decode([t]), "white", on_col), end="") + print("\n", end="") + + +def list_jinja_templates() -> List[Template]: + return TEMPLATES + + +def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat): + by_name = {t.template_name: t for t in TEMPLATES} + if name not in by_name: + raise ValueError(f"No template found for `{name}`") + + template = by_name[name] + interface = LLama31Interface(tool_prompt_format) + + data_func = getattr(template_data, template.data_provider) + if template.role == "system": + messages = interface.system_messages(**data_func()) + elif template.role == "tool": + messages = interface.tool_response_messages(**data_func()) + elif template.role == "assistant": + messages = interface.assistant_response_messages(**data_func()) + elif template.role == "user": + messages = interface.user_message(**data_func()) + + tokens = interface.get_tokens(messages) + special_tokens = list(interface.tokenizer.special_tokens.values()) + tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens] + return template, tokens diff --git a/llama_stack/models/llama/llama3/pasta.jpeg b/llama_stack/models/llama/llama3/pasta.jpeg new file mode 100644 index 000000000..e8299321c Binary files /dev/null and b/llama_stack/models/llama/llama3/pasta.jpeg differ diff --git a/llama_stack/models/llama/llama3/prompt_templates/__init__.py b/llama_stack/models/llama/llama3/prompt_templates/__init__.py new file mode 100644 index 000000000..4eed54d12 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
+ +from .base import PromptTemplate, PromptTemplateGeneratorBase # noqa: F401 +from .system_prompts import ( # noqa: F401 + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + PythonListCustomToolGenerator, + SystemDefaultGenerator, +) +from .tool_response import ToolResponseGenerator # noqa: F401 diff --git a/llama_stack/models/llama/llama3/prompt_templates/base.py b/llama_stack/models/llama/llama3/prompt_templates/base.py new file mode 100644 index 000000000..bff2a21e1 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/base.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from dataclasses import dataclass +from typing import Any, Dict, List + +from jinja2 import Template + + +@dataclass +class PromptTemplate: + template: str + data: Dict[str, Any] + + def render(self): + template = Template(self.template) + return template.render(self.data) + + +class PromptTemplateGeneratorBase: + """ + Base class for prompt template generators. + """ + + def gen(self, *args, **kwargs) -> PromptTemplate: + raise NotImplementedError() + + def data_examples(self) -> List[Any]: + raise NotImplementedError() diff --git a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py new file mode 100644 index 000000000..27b1a3502 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
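The generators defined in this file all return the `PromptTemplate` dataclass added in `base.py` above, which is a thin wrapper over `jinja2.Template`. A minimal sketch of that contract, with made-up template text (illustrative only, not part of this patch):

```python
from llama_stack.models.llama.llama3.prompt_templates.base import PromptTemplate

# Hypothetical template and data; render() passes `data` to jinja2.Template.render().
pt = PromptTemplate(
    template="Hello {{ name }}, today is {{ today }}.",
    data={"name": "Llama", "today": "01 January 2025"},
)
print(pt.render())  # -> Hello Llama, today is 01 January 2025.
```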
+ +import textwrap +from datetime import datetime +from typing import Any, List, Optional + +from llama_models.datatypes import ( + BuiltinTool, +) + +from llama_stack.models.llama.datatypes import ( + ToolDefinition, + ToolParamDefinition, +) + +from .base import PromptTemplate, PromptTemplateGeneratorBase + + +class SystemDefaultGenerator(PromptTemplateGeneratorBase): + def gen(self, *args, **kwargs) -> PromptTemplate: + template_str = textwrap.dedent( + """ + Cutting Knowledge Date: December 2023 + Today Date: {{ today }} + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + {"today": datetime.now().strftime("%d %B %Y")}, + ) + + def data_examples(self) -> List[Any]: + return [None] + + +class BuiltinToolGenerator(PromptTemplateGeneratorBase): + def _tool_breakdown(self, tools: List[ToolDefinition]): + builtin_tools, custom_tools = [], [] + for dfn in tools: + if isinstance(dfn.tool_name, BuiltinTool): + builtin_tools.append(dfn) + else: + custom_tools.append(dfn) + + return builtin_tools, custom_tools + + def gen(self, tools: List[ToolDefinition]) -> PromptTemplate: + builtin_tools, custom_tools = self._tool_breakdown(tools) + template_str = textwrap.dedent( + """ + {% if builtin_tools or custom_tools -%} + Environment: ipython + {% endif -%} + {% set builtin_tools = builtin_tools | reject('equalto', 'code_interpreter') | list -%} + {% if builtin_tools -%} + Tools: {{ builtin_tools | join(", ") | trim -}} + {% endif %} + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + { + "builtin_tools": [t.tool_name.value for t in builtin_tools], + "custom_tools": custom_tools, + }, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + # builtin tools + [ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition(tool_name=BuiltinTool.brave_search), + ToolDefinition(tool_name=BuiltinTool.wolfram_alpha), + ], + # only code interpretor + [ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ], + ] + + +class JsonCustomToolGenerator(PromptTemplateGeneratorBase): + def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + Answer the user's question by making use of the following functions if needed. + If none of the function can be used, please say so. + Here is a list of functions in JSON format: + {% for t in custom_tools -%} + {# manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set tparams = t.parameters -%} + {%- set required_params = [] -%} + {%- for name, param in tparams.items() if param.required == true -%} + {%- set _ = required_params.append(name) -%} + {%- endfor -%} + { + "type": "function", + "function": { + "name": "{{tname}}", + "description": "{{tdesc}}", + "parameters": { + "type": "object", + "properties": [ + {%- for name, param in tparams.items() %} + { + "{{name}}": { + "type": "object", + "description": "{{param.description}}" + } + }{% if not loop.last %},{% endif %} + {%- endfor %} + ], + "required": {{ required_params | tojson }} + } + } + } + {% endfor %} + Return function calls in JSON format. 
+ """ + ) + + return PromptTemplate( + template_str.lstrip("\n"), + {"custom_tools": [t.model_dump() for t in custom_tools]}, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="trending_songs", + description="Returns the trending songs on a Music site", + parameters={ + "n": ToolParamDefinition( + param_type="int", + description="The number of songs to return", + required=True, + ), + "genre": ToolParamDefinition( + param_type="str", + description="The genre of the songs to return", + required=False, + ), + }, + ), + ] + ] + + +class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): + def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + You have access to the following functions: + + {% for t in custom_tools %} + {#- manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set modified_params = t.parameters.copy() -%} + {%- for key, value in modified_params.items() -%} + {%- if 'default' in value -%} + {%- set _ = value.pop('default', None) -%} + {%- endif -%} + {%- endfor -%} + {%- set tparams = modified_params | tojson -%} + Use the function '{{ tname }}' to '{{ tdesc }}': + {"name": "{{tname}}", "description": "{{tdesc}}", "parameters": {{tparams}}} + + {% endfor -%} + Think very carefully before calling functions. + If you choose to call a function ONLY reply in the following format with no prefix or suffix: + + {"example_name": "example_value"} + + Reminder: + - If looking for real time information use relevant functions before falling back to brave_search + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Put the entire function call reply on one line + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + {"custom_tools": [t.model_dump() for t in custom_tools]}, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="trending_songs", + description="Returns the trending songs on a Music site", + parameters={ + "n": ToolParamDefinition( + param_type="int", + description="The number of songs to return", + required=True, + ), + "genre": ToolParamDefinition( + param_type="str", + description="The genre of the songs to return", + required=False, + ), + }, + ), + ] + ] + + +class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 + DEFAULT_PROMPT = textwrap.dedent( + """ + You are an expert in composing functions. You are given a question and a set of possible functions. + Based on the question, you will need to make one or more function/tool calls to achieve the purpose. + If none of the function can be used, point it out. If the given question lacks the parameters required by the function, + also point it out. You should only return the function call in tools call sections. 
+ + {{ function_description }} + """.strip("\n") + ) + + def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate: + system_prompt = system_prompt or self.DEFAULT_PROMPT + return PromptTemplate( + system_prompt, + {"function_description": self._gen_function_description(custom_tools)}, + ) + + def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + {% for t in tools -%} + {# manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set tparams = t.parameters -%} + {%- set required_params = [] -%} + {%- for name, param in tparams.items() if param.required == true -%} + {%- set _ = required_params.append(name) -%} + {%- endfor -%} + { + "name": "{{tname}}", + "description": "{{tdesc}}", + "parameters": { + "type": "dict", + "required": {{ required_params | tojson }}, + "properties": { + {%- for name, param in tparams.items() %} + "{{name}}": { + "type": "{{param.param_type}}", + "description": "{{param.description}}"{% if param.default %}, + "default": "{{param.default}}"{% endif %} + }{% if not loop.last %},{% endif %} + {%- endfor %} + } + } + }{% if not loop.last %}, + {% endif -%} + {%- endfor %} + ] + """ + ) + return PromptTemplate( + template_str.strip("\n"), + {"tools": [t.model_dump() for t in custom_tools]}, + ).render() + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="get_weather", + description="Get weather info for places", + parameters={ + "city": ToolParamDefinition( + param_type="string", + description="The name of the city to get the weather for", + required=True, + ), + "metric": ToolParamDefinition( + param_type="string", + description="The metric for weather. Options are: celsius, fahrenheit", + required=False, + default="celsius", + ), + }, + ), + ] + ] diff --git a/llama_stack/models/llama/llama3/prompt_templates/tool_response.py b/llama_stack/models/llama/llama3/prompt_templates/tool_response.py new file mode 100644 index 000000000..3df4dac14 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/tool_response.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
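As a usage sketch for the generator above (illustrative, not part of this patch): rendering its own `get_weather` example yields the pythonic tool-calling system prompt that the Llama 3.2 text models expect.

```python
from llama_stack.models.llama.llama3.prompt_templates.system_prompts import (
    PythonListCustomToolGenerator,
)

generator = PythonListCustomToolGenerator()
tools = generator.data_examples()[0]  # the get_weather ToolDefinition shown above

# A caller-supplied system prompt may also be passed, provided it keeps the
# {{ function_description }} placeholder.
print(generator.gen(tools).render())
```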
+ +import textwrap +from typing import Optional + +from .base import PromptTemplate, PromptTemplateGeneratorBase + + +class ToolResponseGenerator(PromptTemplateGeneratorBase): + def gen( + self, + status: str, + stdout: Optional[str] = None, + stderr: Optional[str] = None, + ): + assert status in [ + "success", + "failure", + ], f"status must be 'success' or 'failure'; Got: {status}" + template_str = textwrap.dedent( + """ + {% if status == "success" %}completed{% else %}failed{% endif %} + {%- if stdout %} + [stdout]{{ stdout }}[/stdout] + {%- endif -%} + {%- if stderr %} + [stderr]{{ stderr }}[/stderr] + {%- endif -%} + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + { + "status": status, + "stdout": stdout, + "stderr": stderr, + }, + ) + + def data_examples(self): + return [ + # success + { + "status": "success", + "stdout": '{"results":["something something"]}', + }, + # failure + { + "status": "failure", + "stderr": "brave_search encounter an error: could not communicate with api.brave.com", + }, + ] diff --git a/llama_stack/models/llama/llama3/template_data.py b/llama_stack/models/llama/llama3/template_data.py new file mode 100644 index 000000000..620816ffc --- /dev/null +++ b/llama_stack/models/llama/llama3/template_data.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from llama_models.datatypes import ( + BuiltinTool, + StopReason, + ToolCall, +) + +from .prompt_templates import ( + BuiltinToolGenerator, + JsonCustomToolGenerator, + ToolResponseGenerator, +) + +INSTRUCTION = "You are a helpful assistant." 
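A short sketch of what the generator above renders for its own success example (illustrative, not part of this patch):

```python
from llama_stack.models.llama.llama3.prompt_templates.tool_response import (
    ToolResponseGenerator,
)

text = ToolResponseGenerator().gen(
    status="success",
    stdout='{"results": ["something something"]}',
).render()
# Renders as:
# completed
# [stdout]{"results": ["something something"]}[/stdout]
print(text)
```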
+ + +def system_message_builtin_tools_only(): + return { + "builtin_tools": BuiltinToolGenerator().data_examples()[0], + "custom_tools": [], + "instruction": INSTRUCTION, + } + + +def system_message_builtin_code_only(): + return { + "builtin_tools": BuiltinToolGenerator().data_examples()[1], + "custom_tools": [], + "instruction": "", + } + + +def system_message_custom_tools_only(): + return { + "builtin_tools": [], + "custom_tools": JsonCustomToolGenerator().data_examples()[0], + "instruction": INSTRUCTION, + } + + +def system_message_builtin_and_custom_tools(): + return { + "builtin_tools": BuiltinToolGenerator().data_examples()[0], + "custom_tools": JsonCustomToolGenerator().data_examples()[0], + "instruction": INSTRUCTION, + } + + +def system_default(): + return { + "builtin_tools": [], + "custom_tools": [], + "instruction": INSTRUCTION, + } + + +def tool_success(): + return ToolResponseGenerator().data_examples()[0] + + +def tool_failure(): + return ToolResponseGenerator().data_examples()[1] + + +def assistant_builtin_tool_call(): + return { + "content": "", + "tool_call": ToolCall( + call_id="uuid", + tool_name=BuiltinTool.brave_search, + arguments={ + "query": "Who won NBA in 2024?", + }, + ), + "stop_reason": StopReason.end_of_message, + } + + +def assistant_custom_tool_call(): + return { + "content": "", + "tool_call": ToolCall( + call_id="uuid", + tool_name="trending_songs", + arguments={"country": "US", "n": 10}, + ), + "stop_reason": StopReason.end_of_turn, + } + + +def assistant_default(): + return { + "content": "Hi, I am a helpful assistant. What can I help you with today?", + "tool_call": None, + "stop_reason": StopReason.end_of_turn, + } + + +def user_default(): + return {"content": "Please tell me how to plan a trip to New York"} + + +def user_images(): + return {"content": "<|image|><|image|>What do these images depict?"} + + +def user_interleaved_images(): + return {"content": "<|image|>Describe the image in one sentence.<|image|>Write a haiku about these images"} diff --git a/llama_stack/models/llama/llama3/test_system_prompts.py b/llama_stack/models/llama/llama3/test_system_prompts.py new file mode 100644 index 000000000..b47b1ff2d --- /dev/null +++ b/llama_stack/models/llama/llama3/test_system_prompts.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
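The data providers above are looked up by name from `render_jinja_template` in `interface.py` (added earlier in this patch). A hedged sketch of that flow, assuming the Llama 3 tokenizer bundled with `llama_models` is available locally:

```python
from llama_models.datatypes import ToolPromptFormat

from llama_stack.models.llama.llama3.interface import (
    list_jinja_templates,
    render_jinja_template,
)

# Enumerate the registered templates (role + name).
for t in list_jinja_templates():
    print(t.role, t.template_name)

# "system-default" resolves to the system_default() provider above.
template, tokens = render_jinja_template("system-default", ToolPromptFormat.json)
for decoded, is_special in tokens:
    print(decoded, end="")
```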
+ +import textwrap +import unittest +from datetime import datetime + +from .prompt_templates import ( + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + PythonListCustomToolGenerator, + SystemDefaultGenerator, +) + + +class PromptTemplateTests(unittest.TestCase): + def check_generator_output(self, generator, expected_text): + example = generator.data_examples()[0] + + pt = generator.gen(example) + text = pt.render() + # print(text) # debugging + assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}" + + def test_system_default(self): + generator = SystemDefaultGenerator() + today = datetime.now().strftime("%d %B %Y") + expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}" + self.check_generator_output(generator, expected_text) + + def test_system_builtin_only(self): + generator = BuiltinToolGenerator() + expected_text = textwrap.dedent( + """ + Environment: ipython + Tools: brave_search, wolfram_alpha + """ + ) + self.check_generator_output(generator, expected_text.strip("\n")) + + def test_system_custom_only(self): + self.maxDiff = None + generator = JsonCustomToolGenerator() + expected_text = textwrap.dedent( + """ + Answer the user's question by making use of the following functions if needed. + If none of the function can be used, please say so. + Here is a list of functions in JSON format: + { + "type": "function", + "function": { + "name": "trending_songs", + "description": "Returns the trending songs on a Music site", + "parameters": { + "type": "object", + "properties": [ + { + "n": { + "type": "object", + "description": "The number of songs to return" + } + }, + { + "genre": { + "type": "object", + "description": "The genre of the songs to return" + } + } + ], + "required": ["n"] + } + } + } + + Return function calls in JSON format. + """ + ) + self.check_generator_output(generator, expected_text.strip("\n")) + + def test_system_custom_function_tag(self): + self.maxDiff = None + generator = FunctionTagCustomToolGenerator() + expected_text = textwrap.dedent( + """ + You have access to the following functions: + + Use the function 'trending_songs' to 'Returns the trending songs on a Music site': + {"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}} + + Think very carefully before calling functions. + If you choose to call a function ONLY reply in the following format with no prefix or suffix: + + {"example_name": "example_value"} + + Reminder: + - If looking for real time information use relevant functions before falling back to brave_search + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Put the entire function call reply on one line + """ + ) + self.check_generator_output(generator, expected_text.strip("\n")) + + def test_llama_3_2_system_zero_shot(self): + generator = PythonListCustomToolGenerator() + expected_text = textwrap.dedent( + """ + You are an expert in composing functions. You are given a question and a set of possible functions. + Based on the question, you will need to make one or more function/tool calls to achieve the purpose. + If none of the function can be used, point it out. 
If the given question lacks the parameters required by the function, + also point it out. You should only return the function call in tools call sections. + + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + { + "name": "get_weather", + "description": "Get weather info for places", + "parameters": { + "type": "dict", + "required": ["city"], + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get the weather for" + }, + "metric": { + "type": "string", + "description": "The metric for weather. Options are: celsius, fahrenheit", + "default": "celsius" + } + } + } + } + ] + """ + ) + self.check_generator_output(generator, expected_text.strip("\n")) + + def test_llama_3_2_provided_system_prompt(self): + generator = PythonListCustomToolGenerator() + expected_text = textwrap.dedent( + """ + Overriding message. + + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + { + "name": "get_weather", + "description": "Get weather info for places", + "parameters": { + "type": "dict", + "required": ["city"], + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get the weather for" + }, + "metric": { + "type": "string", + "description": "The metric for weather. Options are: celsius, fahrenheit", + "default": "celsius" + } + } + } + } + ]""" + ) + user_system_prompt = textwrap.dedent( + """ + Overriding message. + + {{ function_description }} + """ + ) + example = generator.data_examples()[0] + + pt = generator.gen(example, user_system_prompt) + text = pt.render() + assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}" diff --git a/llama_stack/models/llama/llama3_1/__init__.py b/llama_stack/models/llama/llama3_1/__init__.py new file mode 100644 index 000000000..38ee47d66 --- /dev/null +++ b/llama_stack/models/llama/llama3_1/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. diff --git a/llama_stack/models/llama/llama3_1/prompts.py b/llama_stack/models/llama/llama3_1/prompts.py new file mode 100644 index 000000000..edbce3bc0 --- /dev/null +++ b/llama_stack/models/llama/llama3_1/prompts.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import textwrap +from typing import List + +from llama_models.datatypes import ( + BuiltinTool, + RawMessage, + StopReason, + ToolCall, + ToolPromptFormat, +) + +from ..prompt_format import ( + # llama3_1_e2e_tool_call_dialog, + TextCompletionContent, + UseCase, + llama3_1_builtin_tool_call_dialog, + llama3_1_custom_tool_call_dialog, +) + + +def wolfram_alpha_response(): + return textwrap.dedent( + """ + { + "queryresult": { + "success": true, + "inputstring": "100th decimal of pi", + "pods": [ + { + "title": "Input interpretation", + "subpods": [ + { + "title": "", + "plaintext": "100th digit | \u03c0" + } + ] + }, + { + "title": "Nearby digits", + "subpods": [ + { + "title": "", + "plaintext": "...86208998628034825342117067982148086513282306647093..." + } + ] + }, + { + "title": "Result", + "primary": true, + "subpods": [ + { + "title": "", + "plaintext": "7" + } + ] + } + ] + } + } + """ + ) + + +def usecases() -> List[UseCase | str]: + return [ + textwrap.dedent( + """ + # Llama 3.1 - Prompt Formats + ## Tokens + Here is a list of special tokens that are supported by Llama 3.1: + - `<|begin_of_text|>`: Specifies the start of the prompt + - `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models. + - `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch. + - `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool] + - `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool. + - `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios: + - at the end of a direct interaction between the model and the user + - at the end of multiple interactions between the model and any available tools + This token signals to the executor that the model has finished generating a response. + - `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call. + """ + ), + textwrap.dedent( + """ + There are 4 different roles that are supported by Llama 3.1 + - `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively. + - `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model. + - `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".) + - `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts. 
+ """ + ), + UseCase( + title="Llama 3.1 Base Model", + description="Text completion for Llama 3.1 base model uses this format.", + dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")], + notes="Note start special tag", + ), + "## Llama 3.1 Instruct Model", + UseCase( + title="User and assistant conversation", + description="Here is a regular multi-turn user assistant conversation and how its formatted.", + dialogs=[ + [ + RawMessage(role="system", content="You are a helpful assistant"), + RawMessage( + role="user", + content="Answer who are you in the form of jeopardy?", + ), + ] + ], + notes="", + ), + "## Tool Calling Formats", + textwrap.dedent( + """ + The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt: + - Brave Search: Tool call to perform web searches. + - Wolfram Alpha: Tool call to perform complex mathematical calculations. + - Code Interpreter: Enables the model to output python code. + """ + ), + UseCase( + title="Builtin Tool Calling", + description=textwrap.dedent( + """ + Here is an example of a conversation using brave search + """ + ), + dialogs=[llama3_1_builtin_tool_call_dialog()], + notes=textwrap.dedent( + """ + - Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model. + - The message body of the assistant response starts with a special tag <|python_tag|> + - As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call. + - The model tool call response is of the form `tool.call(query="...")` wher tool is `brave_search` or `wolfram_alpha` + """ + ), + ), + UseCase( + title="Builtin Code Interpreter", + description="Here is an actual example of model responding with code", + dialogs=[ + [ + RawMessage(role="system", content="Environment: ipython"), + RawMessage( + role="user", + content="Write code to check if number is prime, use that to see if the number 7 is prime", + ), + ], + ], + notes=textwrap.dedent( + """ + - Model starts with <|python_tag|> and continues writing python code that it needs to be executed + - No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it. + """ + ), + ), + UseCase( + title="Built-in tools full interaction", + description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.", + dialogs=[ + [ + RawMessage( + role="system", + content="Environment: ipython\nTools: brave_search, wolfram_alpha\n", + ), + RawMessage(role="user", content="What is the 100th decimal of pi?"), + RawMessage( + role="assistant", + content="", + stop_reason=StopReason.end_of_message, + tool_calls=[ + ToolCall( + call_id="tool_call_id", + tool_name=BuiltinTool.wolfram_alpha, + arguments={"query": "100th decimal of pi"}, + ) + ], + ), + RawMessage( + role="tool", + content=wolfram_alpha_response(), + ), + ], + ], + notes=textwrap.dedent( + """ + - Note the `<|python_tag|>` in the assistant response. + - Role is `tool` for the wolfram alpha response that is passed back to the model. 
+ - Final message from assistant has <|eot_id|> tag. + """ + ), + ), + "## Zero shot tool calling", + UseCase( + title="JSON based tool calling", + description=textwrap.dedent( + """ + Llama models can now output custom tool calls from a single message to allow easier tool calling. + The following prompts provide an example of how custom tools can be called from the output of the model. + It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor. + """ + ), + dialogs=[llama3_1_custom_tool_call_dialog()], + notes=textwrap.dedent( + """ + - JSON format for providing tools needs name, description and parameters + - Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt + - Instructions for tools added as a user message + - Only single tool calls are supported as of now + """ + ), + ), + # FIXME: This is not working yet as expected + # UseCase( + # title="E2E tool call example", + # description=textwrap.dedent( + # """ + # Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model. + # """ + # ), + # dialogs=[ + # llama3_1_e2e_tool_call_dialog( + # tool_prompt_format=ToolPromptFormat.function_tag + # ) + # ], + # notes="", + # ), + "## Example of a user defined tool calling", + UseCase( + title="`` based tool calling", + description=textwrap.dedent( + """ + Here is an example of how you could also write custom instructions for model to do zero shot tool calling. + In this example, we define a custom tool calling format using the `` tag. + """ + ), + dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)], + notes=textwrap.dedent( + """ + - In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>` + - Instructions for tools added as a user message + """ + ), + ), + ] diff --git a/llama_stack/models/llama/llama3_2/__init__.py b/llama_stack/models/llama/llama3_2/__init__.py new file mode 100644 index 000000000..38ee47d66 --- /dev/null +++ b/llama_stack/models/llama/llama3_2/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. diff --git a/llama_stack/models/llama/llama3_2/prompts_text.py b/llama_stack/models/llama/llama3_2/prompts_text.py new file mode 100644 index 000000000..29557f4be --- /dev/null +++ b/llama_stack/models/llama/llama3_2/prompts_text.py @@ -0,0 +1,235 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
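To make the token description above concrete, here is a hedged sketch of encoding a dialog with the `LLama31Interface` added in this patch; the decoded text is approximate and assumes the default Llama 3.1 chat format and a locally available tokenizer.

```python
from llama_models.datatypes import RawMessage

from llama_stack.models.llama.llama3.interface import LLama31Interface

interface = LLama31Interface()
tokens = interface.get_tokens(
    [
        RawMessage(role="system", content="You are a helpful assistant"),
        RawMessage(role="user", content="Answer who are you in the form of jeopardy?"),
    ]
)
print(interface.tokenizer.decode(tokens))
# Approximately:
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
# You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
#
# Answer who are you in the form of jeopardy?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```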
+import json +import textwrap + +from llama_models.datatypes import ( + RawMessage, + StopReason, + ToolCall, + ToolPromptFormat, +) + +from ..prompt_format import ( + TextCompletionContent, + UseCase, + llama3_1_builtin_code_interpreter_dialog, +) + + +def user_tool_call(): + content = textwrap.dedent( + """ + Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request? + Here is a list of functions in JSON format that you can invoke: + [ + { + "name": "get_user_info", + "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.", + "parameters": { + "type": "dict", + "required": [ + "user_id" + ], + "properties": { + "user_id": { + "type": "integer", + "description": "The unique identifier of the user. It is used to fetch the specific user details from the database." + }, + "special": { + "type": "string", + "description": "Any special information or parameters that need to be considered while fetching user details.", + "default": "none" + } + } + } + } + ] + + Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)] + + NO other text MUST be included. + """ + ) + return content.strip() + + +def system_tool_call(): + content = textwrap.dedent( + """ + You are an expert in composing functions. You are given a question and a set of possible functions. + Based on the question, you will need to make one or more function/tool calls to achieve the purpose. + If none of the function can be used, point it out. If the given question lacks the parameters required by the function, + also point it out. You should only return the function call in tools call sections. + + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + { + "name": "get_weather", + "description": "Get weather info for places", + "parameters": { + "type": "dict", + "required": [ + "city" + ], + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get the weather for" + }, + "metric": { + "type": "string", + "description": "The metric for weather. Options are: celsius, fahrenheit", + "default": "celsius" + } + } + } + } + ] + """ + ) + return content.strip() + + +def usecases(): + return [ + UseCase( + title="User and assistant conversation", + description="Here is a regular multi-turn user assistant conversation and how its formatted.", + dialogs=[ + [ + RawMessage(role="system", content="You are a helpful assistant"), + RawMessage(role="user", content="Who are you?"), + ] + ], + notes="This format is unchanged from Llama3.1", + ), + UseCase( + title="Zero shot function calling", + description=textwrap.dedent( + """ + For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling. + This new format is designed to be more flexible and powerful than the previous format. + All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls. + It is pythonic in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]` instead of the `json` or `` tag that were defined in Llama3.1. 
+ Here is an example for the same, + """ + ), + dialogs=[ + # Zero shot tool calls as system message + [ + RawMessage(role="system", content=system_tool_call()), + RawMessage(role="user", content="What is the weather in SF and Seattle?"), + ], + ], + notes=textwrap.dedent( + """ + - The output supports multiple tool calls natively + - JSON format for defining the functions in the system prompt is similar to Llama3.1 + """ + ), + ), + UseCase( + title="Zero shot function calling with user message", + description=textwrap.dedent( + """ + While the default is to provide all function calls in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message. + """ + ), + dialogs=[ + # Zero shot tool call as user message + [ + RawMessage(role="user", content=user_tool_call()), + ], + ], + notes=textwrap.dedent( + """ + - The tool call format for the model is the same whether your function calls are provided in the system or user message. + - While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls. + """ + ), + ), + UseCase( + title="Code Interpreter", + description=textwrap.dedent( + """ + Code Interpreter continues to work in 3.2 text models similar to Llama 3.1 model family. + Here is an example, + """ + ), + dialogs=[llama3_1_builtin_code_interpreter_dialog()], + notes=textwrap.dedent( + """ + - Note `Environment: ipython` in the system prompt. + - Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>` + """ + ), + ), + UseCase( + title="Zero shot function calling E2E format", + description=textwrap.dedent( + """ + Here is an example of the e2e cycle of tool calls with the model in a muti-step way. + """ + ), + dialogs=[ + [ + RawMessage(role="system", content=system_tool_call()), + RawMessage(role="user", content="What is the weather in SF?"), + RawMessage( + role="assistant", + content="", + stop_reason=StopReason.end_of_turn, + tool_calls=[ + ToolCall( + call_id="cc", + tool_name="get_weather", + arguments={ + "city": "San Francisco", + "metric": "celsius", + }, + ) + ], + ), + RawMessage( + role="tool", + content=json.dumps("25 C"), + ), + ], + ], + notes=textwrap.dedent( + """ + - The output of the function call is provided back to the model as a tool response ( in json format ). + - Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response. + - The model finally summarizes the information from the tool response and returns the result to the user. + """ + ), + tool_prompt_format=ToolPromptFormat.python_list, + ), + UseCase( + title="Prompt format for base models", + description=textwrap.dedent( + """ + For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows + """ + ), + dialogs=[ + TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"), + ], + notes="Same as Llama3.1", + ), + ] diff --git a/llama_stack/models/llama/llama3_2/prompts_vision.py b/llama_stack/models/llama/llama3_2/prompts_vision.py new file mode 100644 index 000000000..c3cfe5e7b --- /dev/null +++ b/llama_stack/models/llama/llama3_2/prompts_vision.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import textwrap +from pathlib import Path + +from llama_models.datatypes import ( + RawMediaItem, + RawMessage, + RawTextItem, +) + +from ..prompt_format import ( + TextCompletionContent, + UseCase, + llama3_1_builtin_tool_call_dialog, + # llama3_1_builtin_tool_call_with_image_dialog, + llama3_2_user_assistant_conversation, +) + + +def usecases(): + this_dir = Path(__file__).parent.parent.resolve() + with open(this_dir / "scripts/resources/dog.jpg", "rb") as f: + img = f.read() + + return [ + llama3_2_user_assistant_conversation(), + UseCase( + title="User and assistant conversation with Images", + description="This example shows how to pass and image to the model as part of the messages.", + dialogs=[ + [ + RawMessage( + role="user", + content=[ + RawMediaItem(data=img), + RawTextItem(text="Describe this image in two sentences"), + ], + ) + ], + ], + notes=textwrap.dedent( + """ + - The `<|image|>` tag is used to indicate presence of the image + - The model isn't an early fusion model so doesn't actually translate an image into several tokens. Instead the cross-attention layers take input "on the side" from a vision encoder + ![Image](mm-model.png) + - Its important to postion the <|image|> tag appropriately in the prompt. Image will only attend to the subsequent text tokens + - The <|image|> tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body + - We recommend using a single image in one prompt + """ + ), + ), + UseCase( + title="Builtin and Zero Shot Tool Calling", + description=textwrap.dedent( + """ + Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only. + Use `Environment: ipython` to enable tools. + Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools. + The same builtin tools as Llama3.1 are available, + - code_interpreter (for executing python code) + - brave_search (to search the web) + - wolfram_alpha (for querying wolfram alpha for mathematical questions) + """, + ), + dialogs=[llama3_1_builtin_tool_call_dialog()], + notes=textwrap.dedent( + """ + - Note the `<|python_tag|>` before `brave_search` function call. + - The `<|eom_id|>` tag is used to indicate the end of the message. + - Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`. + - Tool Calling does NOT work with images in the prompt as of now. + """ + ), + ), + # UseCase( + # title="Tool Calling for vision models", + # description=textwrap.dedent( + # """ + # While Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only, + # they are not able to do tool calling when prompt contains image inputs (along with text). + # The recommended way would be to separate out the image understanding from the tool calling in successive prompts. + # Here is an example of how that could be done, + # """, + # ), + # dialogs=[llama3_1_builtin_tool_call_with_image_dialog()], + # notes=textwrap.dedent( + # """ + # - Instead of a single prompt (image understanding + tool call), we split into two prompts to achieve the same result. 
+ # """ + # ), + # ), + UseCase( + title="Prompt format for base models", + description=textwrap.dedent( + """ + For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows + """ + ), + dialogs=[ + TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"), + ], + notes="- Same as Llama3.1", + ), + UseCase( + title="Prompt format for base models with Image", + description=textwrap.dedent( + """ + For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image, + """ + ), + dialogs=[ + TextCompletionContent( + content=[ + RawMediaItem(data=img), + RawTextItem(text="If I had to write a haiku for this one"), + ] + ), + ], + notes="- Note the placement of the special tags <|begin_of_text|> and <|image|>", + ), + ] diff --git a/llama_stack/models/llama/llama3_3/prompts.py b/llama_stack/models/llama/llama3_3/prompts.py new file mode 100644 index 000000000..14fd86853 --- /dev/null +++ b/llama_stack/models/llama/llama3_3/prompts.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import textwrap +from typing import List + +from llama_models.datatypes import ( + BuiltinTool, + RawMessage, + StopReason, + ToolCall, + ToolPromptFormat, +) + +from ..prompt_format import ( + # llama3_1_e2e_tool_call_dialog, + TextCompletionContent, + UseCase, + llama3_1_builtin_tool_call_dialog, + llama3_1_custom_tool_call_dialog, +) + + +def wolfram_alpha_response(): + return textwrap.dedent( + """ + { + "queryresult": { + "success": true, + "inputstring": "100th decimal of pi", + "pods": [ + { + "title": "Input interpretation", + "subpods": [ + { + "title": "", + "plaintext": "100th digit | \u03c0" + } + ] + }, + { + "title": "Nearby digits", + "subpods": [ + { + "title": "", + "plaintext": "...86208998628034825342117067982148086513282306647093..." + } + ] + }, + { + "title": "Result", + "primary": true, + "subpods": [ + { + "title": "", + "plaintext": "7" + } + ] + } + ] + } + } + """ + ) + + +def usecases() -> List[UseCase | str]: + return [ + textwrap.dedent( + """ + # Llama 3.1 - Prompt Formats + ## Tokens + Here is a list of special tokens that are supported by Llama 3.1: + - `<|begin_of_text|>`: Specifies the start of the prompt + - `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models. + - `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch. + - `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool] + - `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. 
This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool. + - `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios: + - at the end of a direct interaction between the model and the user + - at the end of multiple interactions between the model and any available tools + This token signals to the executor that the model has finished generating a response. + - `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call. + """ + ), + textwrap.dedent( + """ + There are 4 different roles that are supported by Llama 3.1 + - `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively. + - `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model. + - `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".) + - `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts. + """ + ), + UseCase( + title="Llama 3.1 Base Model", + description="Text completion for Llama 3.1 base model uses this format.", + dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")], + notes="Note start special tag", + ), + "## Llama 3.1 Instruct Model", + UseCase( + title="User and assistant conversation", + description="Here is a regular multi-turn user assistant conversation and how its formatted.", + dialogs=[ + [ + RawMessage(role="system", content="You are a helpful assistant"), + RawMessage( + role="user", + content="Answer who are you in the form of jeopardy?", + ), + ] + ], + notes="", + ), + "## Tool Calling Formats", + textwrap.dedent( + """ + The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt: + - Brave Search: Tool call to perform web searches. + - Wolfram Alpha: Tool call to perform complex mathematical calculations. + - Code Interpreter: Enables the model to output python code. + """ + ), + UseCase( + title="Builtin Tool Calling", + description=textwrap.dedent( + """ + Here is an example of a conversation using brave search + """ + ), + dialogs=[llama3_1_builtin_tool_call_dialog()], + notes=textwrap.dedent( + """ + - Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model. + - The message body of the assistant response starts with a special tag <|python_tag|> + - As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call. 
+ - The model tool call response is of the form `tool.call(query="...")` where tool is `brave_search` or `wolfram_alpha` + """ + ), + ), + UseCase( + title="Builtin Code Interpreter", + description="Here is an actual example of the model responding with code", + dialogs=[ + [ + RawMessage(role="system", content="Environment: ipython"), + RawMessage( + role="user", + content="Write code to check if number is prime, use that to see if the number 7 is prime", + ), + ], + ], + notes=textwrap.dedent( + """ + - Model starts with <|python_tag|> and continues writing python code that needs to be executed + - No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it. + """ + ), + ), + UseCase( + title="Built-in tools full interaction", + description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.", + dialogs=[ + [ + RawMessage( + role="system", + content="Environment: ipython\nTools: brave_search, wolfram_alpha\n", + ), + RawMessage(role="user", content="What is the 100th decimal of pi?"), + RawMessage( + content="", + stop_reason=StopReason.end_of_message, + tool_calls=[ + ToolCall( + call_id="tool_call_id", + tool_name=BuiltinTool.wolfram_alpha, + arguments={"query": "100th decimal of pi"}, + ) + ], + ), + RawMessage( + role="tool", + content=wolfram_alpha_response(), + ), + ], + ], + notes=textwrap.dedent( + """ + - Note the `<|python_tag|>` in the assistant response. + - Role is `tool` for the wolfram alpha response that is passed back to the model. + - Final message from assistant has <|eot_id|> tag. + """ + ), + ), + "## Zero shot tool calling", + UseCase( + title="JSON based tool calling", + description=textwrap.dedent( + """ + Llama models can now output custom tool calls from a single message to allow easier tool calling. + The following prompts provide an example of how custom tools can be called from the output of the model. + It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor. + """ + ), + dialogs=[llama3_1_custom_tool_call_dialog()], + notes=textwrap.dedent( + """ + - JSON format for providing tools needs name, description and parameters + - Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt + - Instructions for tools added as a user message + - Only single tool calls are supported as of now + """ + ), + ), + # FIXME: This is not working yet as expected + # UseCase( + # title="E2E tool call example", + # description=textwrap.dedent( + # """ + # Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model. + # """ + # ), + # dialogs=[ + # llama3_1_e2e_tool_call_dialog( + # tool_prompt_format=ToolPromptFormat.function_tag + # ) + # ], + # notes="", + # ), + "## Example of a user defined tool calling", + UseCase( + title="`` based tool calling", + description=textwrap.dedent( + """ + Here is an example of how you could also write custom instructions for the model to do zero-shot tool calling. + In this example, we define a custom tool calling format using the `` tag.
+ """ + ), + dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)], + notes=textwrap.dedent( + """ + - In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>` + - Instructions for tools added as a user message + """ + ), + ), + ] diff --git a/llama_stack/models/llama/prompt_format.py b/llama_stack/models/llama/prompt_format.py new file mode 100644 index 000000000..f42620d57 --- /dev/null +++ b/llama_stack/models/llama/prompt_format.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import json +import textwrap +from pathlib import Path +from typing import List + +from llama_models.datatypes import ( + RawContent, + RawMediaItem, + RawMessage, + RawTextItem, + StopReason, + ToolCall, + ToolPromptFormat, +) +from pydantic import BaseModel, Field + +from .llama3.interface import LLama31Interface +from .llama3.template_data import ( + system_message_builtin_code_only, + system_message_builtin_tools_only, + system_message_custom_tools_only, +) + + +class TextCompletionContent(BaseModel): + content: RawContent = "" + + +class UseCase(BaseModel): + title: str = "" + description: str = "" + dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list) + notes: str = "" + tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json + + def md_format(self): + section = textwrap.dedent( + """ + ## {title} + + {description} + + {dialogs_text} + {notes} + + """ + ) + return section.lstrip() + + def dialogs_to_text(self, generator) -> str: + def _code_block(text): + return f"```\n{text}\n```" + + text = "" + for dialog in self.dialogs: + if isinstance(dialog, str): + text += dialog + text += "\n\n" + continue + + elif isinstance(dialog, TextCompletionContent): + input_tokens, output_tokens = generator.text_completion_raw( + dialog.content, + max_gen_len=64, + temperature=0.1, + top_p=0.95, + ) + else: + input_tokens, output_tokens = generator.chat_completion_raw( + dialog, + max_gen_len=512, + temperature=0.0, + top_p=0.95, + tool_prompt_format=self.tool_prompt_format, + ) + text += "##### Input Prompt Format\n" + + # FIXME: This is added to undo the hack in chat_formatter where + # vision tokens are replaced with 128256. 
+ input_tokens = [generator.formatter.vision_token if t == 128256 else t for t in input_tokens] + + text += _code_block(generator.tokenizer.decode(input_tokens)) + # TODO: Figure out if "↵" needs to be added for newlines or end or some indication + text += "\n\n" + text += "##### Model Response Format\n" + text += _code_block(generator.tokenizer.decode(output_tokens)) + text += "\n\n" + + return text + + def to_text(self, generator): + section = self.md_format() + dialogs_text = self.dialogs_to_text(generator) + notes = f"##### Notes\n{self.notes}" if self.notes else "" + section = section.format( + title=self.title, + description=self.description, + dialogs_text=dialogs_text, + notes=notes, + ) + return section + + +def llama3_1_builtin_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json): + interface = LLama31Interface(tool_prompt_format) + + messages = interface.system_messages(**system_message_builtin_tools_only()) + messages += interface.user_message(content="Search the web for the latest price of 1oz gold?") + + return messages + + +def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat.json): + interface = LLama31Interface(tool_prompt_format) + + messages = interface.system_messages(**system_message_builtin_code_only()) + messages += interface.user_message( + content="Write code to check if number is prime. Use it to verify if number 7 is prime" + ) + + return messages + + +def llama3_1_builtin_tool_call_with_image_dialog( + tool_prompt_format=ToolPromptFormat.json, +): + this_dir = Path(__file__).parent + with open(this_dir / "llama3/dog.jpg", "rb") as f: + img = f.read() + + interface = LLama31Interface(tool_prompt_format) + + messages = interface.system_messages(**system_message_builtin_tools_only()) + messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")]) + messages += interface.assistant_response_messages( + "Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix", + StopReason.end_of_turn, + ) + messages += interface.user_message("Search the web for some food recommendations for the identified breed") + return messages + + +def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json): + interface = LLama31Interface(tool_prompt_format) + + messages = interface.system_messages(**system_message_custom_tools_only()) + messages += interface.user_message(content="Use tools to get latest trending songs") + return messages + + +def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json): + tool_response = json.dumps(["great song1", "awesome song2", "cool song3"]) + interface = LLama31Interface(tool_prompt_format) + + messages = interface.system_messages(**system_message_custom_tools_only()) + messages += interface.user_message(content="Use tools to get latest trending songs") + messages.append( + RawMessage( + role="assistant", + content="", + stop_reason=StopReason.end_of_message, + tool_calls=[ + ToolCall( + call_id="call_id", + tool_name="trending_songs", + arguments={"n": "10", "genre": "latest"}, + ) + ], + ), + ) + messages.append( + RawMessage( + role="assistant", + content=tool_response, + ) + ) + return messages + + +def llama3_2_user_assistant_conversation(): + return UseCase( + title="User and assistant conversation", + description="Here is a regular multi-turn user assistant conversation and how it is formatted.", + dialogs=[ + [ + RawMessage(role="system", content="You are a helpful assistant"), 
RawMessage(role="user", content="Who are you?"), + ] + ], + notes="This format is unchanged from Llama3.1", + ) diff --git a/llama_stack/models/llama/sku_list.py b/llama_stack/models/llama/sku_list.py new file mode 100644 index 000000000..6f4a5a885 --- /dev/null +++ b/llama_stack/models/llama/sku_list.py @@ -0,0 +1,1000 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Optional + +from .datatypes import ( + CheckpointQuantizationFormat, + CoreModelId, + Model, + SamplingParams, + TopPSamplingStrategy, +) + +LLAMA2_VOCAB_SIZE = 32000 +LLAMA3_VOCAB_SIZE = 128256 + + +def resolve_model(descriptor: str) -> Optional[Model]: + for m in all_registered_models(): + if descriptor in (m.descriptor(), m.huggingface_repo): + return m + return None + + +def all_registered_models() -> List[Model]: + return ( + llama2_family() + llama3_family() + llama3_1_family() + llama3_2_family() + llama3_3_family() + safety_models() + ) + + +def recommended_sampling_params() -> SamplingParams: + return SamplingParams( + strategy=TopPSamplingStrategy( + temperature=1.0, + top_p=0.9, + ) + ) + + +def llama2_family() -> List[Model]: + return [ + *llama2_base_models(), + *llama2_instruct_models(), + ] + + +def llama3_family() -> List[Model]: + return [ + *llama3_base_models(), + *llama3_instruct_models(), + ] + + +def llama3_1_family() -> List[Model]: + return [ + *llama3_1_base_models(), + *llama3_1_instruct_models(), + ] + + +def llama3_2_family() -> List[Model]: + return [ + *llama3_2_base_models(), + *llama3_2_instruct_models(), + ] + + +def llama3_3_family() -> List[Model]: + return [ + *llama3_3_instruct_models(), + ] + + +def llama2_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama2_7b, + description="Llama 2 7b model", + huggingface_repo="meta-llama/Llama-2-7b", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_13b, + description="Llama 2 13b model", + huggingface_repo="meta-llama/Llama-2-13b", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 5120, + "n_layers": 40, + "n_heads": 40, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_70b, + description="Llama 2 70b model", + huggingface_repo="meta-llama/Llama-2-70b", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + 
"use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_8b, + description="Llama 3 8b model", + huggingface_repo="meta-llama/Llama-3-8B", + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_70b, + description="Llama 3 70b model", + huggingface_repo="meta-llama/Llama-3-70B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_1_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_1_8b, + description="Llama 3.1 8b model", + huggingface_repo="meta-llama/Llama-3.1-8B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_1_70b, + description="Llama 3.1 70b model", + huggingface_repo="meta-llama/Llama-3.1-70B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b, + variant="bf16-mp8", + description="Llama 3.1 405b model (BF16 weights)", + huggingface_repo="meta-llama/Llama-3.1-405B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b, + description="Llama 3.1 405b model (FP8 quantized)", + huggingface_repo="meta-llama/Llama-3.1-405B-FP8", + quantization_format=CheckpointQuantizationFormat.fp8_mixed, + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b, + variant="bf16-mp16", + description="Llama 3.1 405b model (BF16 weights for mp16)", + huggingface_repo="meta-llama/Llama-3.1-405B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 16, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + 
pth_file_count=16, + ), + ] + + +def llama3_2_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_2_1b, + description="Llama 3.2 1b model", + huggingface_repo="meta-llama/Llama-3.2-1B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 2048, + "n_layers": 16, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.5, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b, + description="Llama 3.2 3b model", + huggingface_repo="meta-llama/Llama-3.2-3B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 3072, + "n_layers": 28, + "n_heads": 24, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.0, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_11b_vision, + description="Llama 3.2 11b vision model", + huggingface_repo="meta-llama/Llama-3.2-11B-Vision", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 448, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 8, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_90b_vision, + description="Llama 3.2 90b vision model", + huggingface_repo="meta-llama/Llama-3.2-90B-Vision", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 20, + }, + pth_file_count=8, + ), + ] + + +def llama2_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama2_7b_chat, + description="Llama 2 7b chat model", + huggingface_repo="meta-llama/Llama-2-7b-chat", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_13b_chat, + description="Llama 2 13b chat model", + huggingface_repo="meta-llama/Llama-2-13b-chat", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 5120, + "n_layers": 40, + "n_heads": 40, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_70b_chat, + description="Llama 2 70b chat model", + huggingface_repo="meta-llama/Llama-2-70b-chat", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + 
"ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_8b_instruct, + description="Llama 3 8b instruct model", + huggingface_repo="meta-llama/Llama-3-8B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_70b_instruct, + description="Llama 3 70b instruct model", + huggingface_repo="meta-llama/Llama-3-70B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_1_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_1_8b_instruct, + description="Llama 3.1 8b instruct model", + huggingface_repo="meta-llama/Llama-3.1-8B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_1_70b_instruct, + description="Llama 3.1 70b instruct model", + huggingface_repo="meta-llama/Llama-3.1-70B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b_instruct, + variant="bf16-mp8", + description="Llama 3.1 405b instruct model (BF16 weights)", + huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b_instruct, + description="Llama 3.1 405b instruct model (FP8 quantized)", + huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8", + quantization_format=CheckpointQuantizationFormat.fp8_mixed, + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b_instruct, + variant="bf16-mp16", + description="Llama 3.1 405b instruct model (BF16 weights for mp16)", + 
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 16, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=16, + ), + ] + + +def arch_args_1b() -> dict: + return { + "dim": 2048, + "n_layers": 16, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.5, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + } + + +def arch_args_3b() -> dict: + return { + "dim": 3072, + "n_layers": 28, + "n_heads": 24, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.0, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + } + + +def llama3_2_quantized_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_2_1b_instruct, + variant="int4-qlora-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 1b INT4 quantized LoRA", + huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_1b(), + "quantization_args": { + "group_size": 256, + }, + "lora_args": { + "rank": 16, + "scale": 2.0, + }, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_1b_instruct, + variant="int4-spinquant-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 1b INT4 quantized SpinQuant", + huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_1b(), + "quantization_args": { + "group_size": 256, + }, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b_instruct, + variant="int4-qlora-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 3b INT4 quantized LoRA", + huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_3b(), + "quantization_args": { + "group_size": 256, + }, + "lora_args": { + "rank": 16, + "scale": 2.0, + }, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b_instruct, + variant="int4-spinquant-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 3b INT4 quantized SpinQuant", + huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_3b(), + "quantization_args": { + "group_size": 256, + }, + }, + pth_file_count=1, + ), + ] + + +def llama3_2_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_2_1b_instruct, + description="Llama 3.2 1b instruct model", + huggingface_repo="meta-llama/Llama-3.2-1B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args=arch_args_1b(), + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b_instruct, + description="Llama 3.2 3b instruct model", + huggingface_repo="meta-llama/Llama-3.2-3B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args=arch_args_3b(), + pth_file_count=1, + ), + *llama3_2_quantized_models(), + 
Model( + core_model_id=CoreModelId.llama3_2_11b_vision_instruct, + description="Llama 3.2 11b vision instruct model", + huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 8, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_90b_vision_instruct, + description="Llama 3.2 90b vision instruct model", + huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 20, + }, + pth_file_count=8, + ), + ] + + +def llama3_3_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_3_70b_instruct, + description="Llama 3.3 70b instruct", + huggingface_repo="meta-llama/Llama-3.3-70B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + ] + + +@lru_cache +def safety_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama_guard_3_11b_vision, + description="Llama Guard v3 11b vision system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 8, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_3_1b, + variant="int4", + description="Llama Guard v3 1b 'int4' quantized system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4", + quantization_format=CheckpointQuantizationFormat.int4, + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 2048, + "n_layers": 12, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "rope_freq_base": 500000.0, + "norm_eps": 1e-05, + "hidden_dim": 6400, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_3_1b, + description="Llama Guard v3 1b system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-1B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 2048, + "n_layers": 16, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.5, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + 
core_model_id=CoreModelId.llama_guard_3_8b, + description="Llama Guard v3 8b system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-8B", + arch_args={ + "dim": 4096, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "n_heads": 32, + "n_kv_heads": 8, + "n_layers": 32, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + "vocab_size": LLAMA3_VOCAB_SIZE, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_3_8b, + variant="int8", + description="Llama Guard v3 8b system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-8B-INT8", + quantization_format=CheckpointQuantizationFormat.int8, + arch_args={ + "dim": 4096, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "n_heads": 32, + "n_kv_heads": 8, + "n_layers": 32, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + "vocab_size": LLAMA3_VOCAB_SIZE, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_2_8b, + description="Llama Guard v2 8b system safety model", + huggingface_repo="meta-llama/Llama-Guard-2-8B", + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + ] + + +@dataclass +class LlamaDownloadInfo: + folder: str + files: List[str] + pth_size: int + + +def llama_meta_net_info(model: Model) -> LlamaDownloadInfo: + """Information needed to download model from llamameta.net""" + + pth_count = model.pth_file_count + if model.core_model_id == CoreModelId.llama3_1_405b: + if pth_count == 16: + folder = "Llama-3.1-405B-MP16" + elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + folder = "Llama-3.1-405B" + else: + folder = "Llama-3.1-405B-MP8" + elif model.core_model_id == CoreModelId.llama3_1_405b_instruct: + if pth_count == 16: + folder = "Llama-3.1-405B-Instruct-MP16" + elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + folder = "Llama-3.1-405B-Instruct" + else: + folder = "Llama-3.1-405B-Instruct-MP8" + elif model.core_model_id == CoreModelId.llama_guard_3_8b: + if model.quantization_format == CheckpointQuantizationFormat.int8: + folder = "Llama-Guard-3-8B-INT8-HF" + else: + folder = "Llama-Guard-3-8B" + elif model.core_model_id == CoreModelId.llama_guard_2_8b: + folder = "llama-guard-2" + else: + folder = model.huggingface_repo.split("/")[-1] + if "Llama-2" in folder: + folder = folder.lower() + + files = ["checklist.chk"] + if ( + model.core_model_id == CoreModelId.llama_guard_3_8b + and model.quantization_format == CheckpointQuantizationFormat.int8 + ): + files.extend( + [ + "generation_config.json", + "model-00001-of-00002.safetensors", + "model-00002-of-00002.safetensors", + "special_tokens_map.json", + "tokenizer.json", + "tokenizer_config.json", + "model.safetensors.index.json", + ] + ) + elif ( + model.core_model_id == CoreModelId.llama_guard_3_1b + and model.quantization_format == CheckpointQuantizationFormat.int4 + ): + files.extend( + [ + "llama_guard_3_1b_pruned_xnnpack.pte", + "example-prompt.txt", + "params.json", + "tokenizer.model", + ] + ) + else: + files.extend( + [ + "tokenizer.model", + "params.json", + ] + ) + if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + files.extend([f"fp8_scales_{i}.pt" for i in range(pth_count)]) + files.extend([f"consolidated.{i:02d}.pth" for i in range(pth_count)]) + + return 
LlamaDownloadInfo( + folder=folder, + files=files, + pth_size=llama_meta_pth_size(model), + ) + + +# Sadness because Cloudfront rejects our HEAD requests to find Content-Length +def llama_meta_pth_size(model: Model) -> int: + if model.core_model_id not in ( + CoreModelId.llama3_1_405b, + CoreModelId.llama3_1_405b_instruct, + ): + return 0 + + if model.pth_file_count == 16: + return 51268302389 + elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + return 60903742309 + else: + return 101470976045 diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index ccdaf76e7..384582423 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -7,17 +7,17 @@ from typing import Any, List, Optional, Protocol from urllib.parse import urlparse -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.apis.benchmarks import Benchmark from llama_stack.apis.datasets import Dataset from llama_stack.apis.datatypes import Api -from llama_stack.apis.eval_tasks import EvalTask from llama_stack.apis.models import Model from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.shields import Shield from llama_stack.apis.tools import Tool from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.schema_utils import json_schema_type class ModelsProtocolPrivate(Protocol): @@ -48,8 +48,8 @@ class ScoringFunctionsProtocolPrivate(Protocol): async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ... -class EvalTasksProtocolPrivate(Protocol): - async def register_eval_task(self, eval_task: EvalTask) -> None: ... +class BenchmarksProtocolPrivate(Protocol): + async def register_benchmark(self, benchmark: Benchmark) -> None: ... 
class ToolsProtocolPrivate(Protocol): diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 8ba7885cd..1c21df57f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple from urllib.parse import urlparse import httpx -from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition from pydantic import TypeAdapter from llama_stack.apis.agents import ( @@ -63,6 +62,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO +from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing @@ -301,6 +301,7 @@ class ChatAgent(ShieldRunnerMixin): return step_id = str(uuid.uuid4()) + shield_call_start_time = datetime.now() try: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -323,6 +324,8 @@ class ChatAgent(ShieldRunnerMixin): step_id=step_id, turn_id=turn_id, violation=e.violation, + started_at=shield_call_start_time, + completed_at=datetime.now(), ), ) ) @@ -344,6 +347,8 @@ class ChatAgent(ShieldRunnerMixin): step_id=step_id, turn_id=turn_id, violation=None, + started_at=shield_call_start_time, + completed_at=datetime.now(), ), ) ) @@ -476,6 +481,7 @@ class ChatAgent(ShieldRunnerMixin): client_tools[tool.name] = tool while True: step_id = str(uuid.uuid4()) + inference_start_time = datetime.now() yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepStartPayload( @@ -574,6 +580,8 @@ class ChatAgent(ShieldRunnerMixin): step_id=step_id, turn_id=turn_id, model_response=copy.deepcopy(message), + started_at=inference_start_time, + completed_at=datetime.now(), ), ) ) @@ -641,6 +649,7 @@ class ChatAgent(ShieldRunnerMixin): "input": message.model_dump_json(), }, ) as span: + tool_execution_start_time = datetime.now() result_messages = await execute_tool_call_maybe( self.tool_runtime_api, session_id, @@ -668,6 +677,8 @@ class ChatAgent(ShieldRunnerMixin): content=result_message.content, ) ], + started_at=tool_execution_start_time, + completed_at=datetime.now(), ), ) ) diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 30ce52e3b..2497be070 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -41,7 +41,7 @@ class ShieldRunnerMixin: for identifier in identifiers ] ) - for identifier, response in zip(identifiers, responses): + for identifier, response in zip(identifiers, responses, strict=False): if not response.violation: continue diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py index 4e3951ad3..b802937b6 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py +++ 
b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py @@ -8,7 +8,6 @@ import tempfile from typing import AsyncIterator, List, Optional, Union import pytest -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.agents import ( AgentConfig, @@ -41,6 +40,7 @@ from llama_stack.apis.tools import ( ToolInvocationResult, ) from llama_stack.apis.vector_io import QueryChunksResponse +from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( MEMORY_QUERY_TOOL, ) diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 1c44caf7f..0f77b7347 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -8,13 +8,13 @@ from typing import Any, Dict, List, Optional from tqdm import tqdm from llama_stack.apis.agents import Agents, StepType +from llama_stack.apis.benchmarks import Benchmark from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets -from llama_stack.apis.eval_tasks import EvalTask from llama_stack.apis.inference import Inference, UserMessage from llama_stack.apis.scoring import Scoring from llama_stack.distribution.datatypes import Api -from llama_stack.providers.datatypes import EvalTasksProtocolPrivate +from llama_stack.providers.datatypes import BenchmarksProtocolPrivate from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( MEMORY_QUERY_TOOL, ) @@ -26,15 +26,15 @@ from llama_stack.providers.utils.common.data_schema_validator import ( from llama_stack.providers.utils.kvstore import kvstore_impl from .....apis.common.job_types import Job -from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus +from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus from .config import MetaReferenceEvalConfig -EVAL_TASKS_PREFIX = "eval_tasks:" +EVAL_TASKS_PREFIX = "benchmarks:" class MetaReferenceEvalImpl( Eval, - EvalTasksProtocolPrivate, + BenchmarksProtocolPrivate, ): def __init__( self, @@ -55,36 +55,36 @@ class MetaReferenceEvalImpl( # TODO: assume sync job, will need jobs API for async scheduling self.jobs = {} - self.eval_tasks = {} + self.benchmarks = {} async def initialize(self) -> None: self.kvstore = await kvstore_impl(self.config.kvstore) - # Load existing eval_tasks from kvstore + # Load existing benchmarks from kvstore start_key = EVAL_TASKS_PREFIX end_key = f"{EVAL_TASKS_PREFIX}\xff" - stored_eval_tasks = await self.kvstore.range(start_key, end_key) + stored_benchmarks = await self.kvstore.range(start_key, end_key) - for eval_task in stored_eval_tasks: - eval_task = EvalTask.model_validate_json(eval_task) - self.eval_tasks[eval_task.identifier] = eval_task + for benchmark in stored_benchmarks: + benchmark = Benchmark.model_validate_json(benchmark) + self.benchmarks[benchmark.identifier] = benchmark async def shutdown(self) -> None: ... 
- async def register_eval_task(self, task_def: EvalTask) -> None: + async def register_benchmark(self, task_def: Benchmark) -> None: # Store in kvstore key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}" await self.kvstore.set( key=key, value=task_def.model_dump_json(), ) - self.eval_tasks[task_def.identifier] = task_def + self.benchmarks[task_def.identifier] = task_def async def run_eval( self, - task_id: str, - task_config: EvalTaskConfig, + benchmark_id: str, + task_config: BenchmarkConfig, ) -> Job: - task_def = self.eval_tasks[task_id] + task_def = self.benchmarks[benchmark_id] dataset_id = task_def.dataset_id candidate = task_config.eval_candidate scoring_functions = task_def.scoring_functions @@ -95,7 +95,7 @@ class MetaReferenceEvalImpl( rows_in_page=(-1 if task_config.num_examples is None else task_config.num_examples), ) res = await self.evaluate_rows( - task_id=task_id, + benchmark_id=benchmark_id, input_rows=all_rows.rows, scoring_functions=scoring_functions, task_config=task_config, @@ -108,7 +108,7 @@ class MetaReferenceEvalImpl( return Job(job_id=job_id) async def _run_agent_generation( - self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig + self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig ) -> List[Dict[str, Any]]: candidate = task_config.eval_candidate create_response = await self.agents_api.create_agent(candidate.config) @@ -151,7 +151,7 @@ class MetaReferenceEvalImpl( return generations async def _run_model_generation( - self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig + self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig ) -> List[Dict[str, Any]]: candidate = task_config.eval_candidate assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided" @@ -187,10 +187,10 @@ class MetaReferenceEvalImpl( async def evaluate_rows( self, - task_id: str, + benchmark_id: str, input_rows: List[Dict[str, Any]], scoring_functions: List[str], - task_config: EvalTaskConfig, + task_config: BenchmarkConfig, ) -> EvaluateResponse: candidate = task_config.eval_candidate if candidate.type == "agent": @@ -201,9 +201,11 @@ class MetaReferenceEvalImpl( raise ValueError(f"Invalid candidate type: {candidate.type}") # scoring with generated_answer - score_input_rows = [input_r | generated_r for input_r, generated_r in zip(input_rows, generations)] + score_input_rows = [ + input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False) + ] - if task_config.type == "app" and task_config.scoring_params is not None: + if task_config.scoring_params is not None: scoring_functions_dict = { scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None) for scoring_fn_id in scoring_functions @@ -217,18 +219,60 @@ class MetaReferenceEvalImpl( return EvaluateResponse(generations=generations, scores=score_response.results) - async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: + async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: if job_id in self.jobs: return JobStatus.completed return None - async def job_cancel(self, task_id: str, job_id: str) -> None: + async def job_cancel(self, benchmark_id: str, job_id: str) -> None: raise NotImplementedError("Job cancel is not implemented yet") - async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: - status = await self.job_status(task_id, job_id) + async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: + status = await 
self.job_status(benchmark_id, job_id) if not status or status != JobStatus.completed: raise ValueError(f"Job is not completed, Status: {status.value}") return self.jobs[job_id] + + async def DEPRECATED_run_eval( + self, + task_id: str, + task_config: BenchmarkConfig, + ) -> Job: + return await self.run_eval(benchmark_id=task_id, task_config=task_config) + + async def DEPRECATED_evaluate_rows( + self, + task_id: str, + input_rows: List[Dict[str, Any]], + scoring_functions: List[str], + task_config: BenchmarkConfig, + ) -> EvaluateResponse: + return await self.evaluate_rows( + benchmark_id=task_id, + input_rows=input_rows, + scoring_functions=scoring_functions, + task_config=task_config, + ) + + async def DEPRECATED_job_status( + self, + task_id: str, + job_id: str, + ) -> Optional[JobStatus]: + return await self.job_status(benchmark_id=task_id, job_id=job_id) + + async def DEPRECATED_job_cancel( + self, + task_id: str, + job_id: str, + ) -> None: + return await self.job_cancel(benchmark_id=task_id, job_id=job_id) + + async def DEPRECATED_job_result( + self, + task_id: str, + job_id: str, + ) -> EvaluateResponse: + return await self.job_result(benchmark_id=task_id, job_id=job_id) diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index e60c3b1be..2d2ec5c8f 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -23,20 +23,13 @@ from fairscale.nn.model_parallel.initialize import ( initialize_model_parallel, model_parallel_is_initialized, ) -from llama_models.datatypes import ( - GreedySamplingStrategy, - SamplingParams, - TopPSamplingStrategy, -) from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.api.chat_format import ChatFormat, LLMInput -from llama_models.llama3.api.datatypes import Model from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.reference_impl.model import Transformer from llama_models.llama3.reference_impl.multimodal.model import ( CrossAttentionTransformer, ) -from llama_models.sku_list import resolve_model from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData from pydantic import BaseModel @@ -47,6 +40,13 @@ from llama_stack.apis.inference import ( ResponseFormatType, ) from llama_stack.distribution.utils.model_utils import model_local_dir +from llama_stack.models.llama.datatypes import ( + GreedySamplingStrategy, + Model, + SamplingParams, + TopPSamplingStrategy, +) +from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent, diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 61f0ee3f4..c79f97def 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -8,14 +8,6 @@ import asyncio import logging from typing import AsyncGenerator, List, Optional, Union -from llama_models.llama3.api.datatypes import ( - SamplingParams, - StopReason, - ToolDefinition, - ToolPromptFormat, -) -from llama_models.sku_list import resolve_model - from llama_stack.apis.common.content_types import ( TextDelta, ToolCallDelta, @@ -41,6 +33,13 @@ from llama_stack.apis.inference import ( 
ToolConfig, ) from llama_stack.apis.models import Model, ModelType +from llama_stack.models.llama.datatypes import ( + SamplingParams, + StopReason, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py index ef133274c..64f94a69d 100644 --- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py +++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py @@ -10,10 +10,10 @@ from functools import partial from typing import Any, Generator from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Model from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import resolve_model +from llama_stack.models.llama.datatypes import Model +from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent, diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py index 9be35ae70..a2dc00916 100644 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py @@ -14,14 +14,14 @@ from typing import Any, Dict, List, Optional import torch from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region -from llama_models.datatypes import CheckpointQuantizationFormat from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock -from llama_models.sku_list import resolve_model from torch import Tensor, nn from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear from llama_stack.apis.inference import QuantizationType +from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat +from llama_stack.models.llama.sku_list import resolve_model from ..config import MetaReferenceQuantizedInferenceConfig diff --git a/llama_stack/providers/inline/inference/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py index de2bae265..51ef2d273 100644 --- a/llama_stack/providers/inline/inference/vllm/config.py +++ b/llama_stack/providers/inline/inference/vllm/config.py @@ -4,10 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field, field_validator from llama_stack.providers.utils.inference import supported_inference_models +from llama_stack.schema_utils import json_schema_type @json_schema_type diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index e75a9aac3..5536ea3a5 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -11,7 +11,6 @@ from typing import AsyncGenerator, List, Optional from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import resolve_model from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.sampling_params import SamplingParams as VLLMSamplingParams @@ -35,6 +34,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model +from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import ( OpenAICompatCompletionChoice, diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index 735af8c79..98e16f9d7 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -13,8 +13,6 @@ from typing import Any, Callable, Dict import torch -from llama_models.datatypes import Model -from llama_models.sku_list import resolve_model from pydantic import BaseModel from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages from torchtune.models.llama3 import llama3_tokenizer @@ -24,6 +22,8 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b from torchtune.modules.transforms import Transform from llama_stack.apis.post_training import DatasetFormat +from llama_stack.models.llama.datatypes import Model +from llama_stack.models.llama.sku_list import resolve_model class ModelConfig(BaseModel): diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index ba11736d6..c77d9305f 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -6,8 +6,6 @@ from datetime import datetime from typing import Any, Dict, Optional -from llama_models.schema_utils import webmethod - from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import ( @@ -27,6 +25,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import ( from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import ( LoraFinetuningSingleDevice, ) +from llama_stack.schema_utils import webmethod class TorchtunePostTrainingImpl: diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index ef379aff2..4ab59fec4 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ 
b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -14,7 +14,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import torch -from llama_models.sku_list import resolve_model from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler @@ -46,6 +45,7 @@ from llama_stack.apis.post_training import ( ) from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.distribution.utils.model_utils import model_local_dir +from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.inline.post_training.common.validator import ( validate_input_dataset_schema, ) diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 32d6d5100..af0987fa8 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -8,9 +8,6 @@ import re from string import Template from typing import Any, Dict, List, Optional -from llama_models.datatypes import CoreModelId -from llama_models.llama3.api.datatypes import Role - from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.apis.inference import ( ChatCompletionResponseEventType, @@ -26,6 +23,7 @@ from llama_stack.apis.safety import ( ) from llama_stack.apis.shields import Shield from llama_stack.distribution.datatypes import Api +from llama_stack.models.llama.datatypes import CoreModelId, Role from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py index f28ae248c..1850d69f7 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py @@ -83,12 +83,6 @@ import sys as _sys from contextlib import ( # noqa contextmanager as _contextmanager, ) -from contextlib import ( - redirect_stderr as _redirect_stderr, -) -from contextlib import ( - redirect_stdout as _redirect_stdout, -) from multiprocessing.connection import Connection as _Connection # Mangle imports to avoid polluting model execution namespace. 
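The `strict=False` arguments added to `zip(...)` calls in this change (the agents safety runner, the meta-reference eval provider, and the RAG/vector-io providers below) rely on the Python 3.10+ `strict` keyword. The sketch below is illustrative only, assuming the goal is to keep the historical silent-truncation behaviour of `zip()` while making that choice explicit for a linter such as ruff's zip-without-strict rule:

```python
# Illustrative sketch only (not part of this diff); requires Python >= 3.10.
chunks = ["chunk-a", "chunk-b", "chunk-c"]
scores = [0.2, 0.9]  # deliberately one element short

# strict=False keeps the historical zip() behaviour: truncate to the shorter input.
print(list(zip(chunks, scores, strict=False)))  # [('chunk-a', 0.2), ('chunk-b', 0.9)]

# strict=True would raise instead, which is useful when the inputs must stay in lockstep.
try:
    list(zip(chunks, scores, strict=True))
except ValueError as err:
    print(f"length mismatch: {err}")

# The RAG memory tool below also uses the zip/sort/unzip idiom to order two
# parallel sequences by score, highest first.
pairs = sorted(zip(chunks[:2], scores, strict=False), key=lambda x: x[1], reverse=True)
top_chunks, top_scores = zip(*pairs, strict=False)
print(top_chunks, top_scores)  # ('chunk-b', 'chunk-a') (0.9, 0.2)
```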
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 5695d4037..a6cd57923 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -118,7 +118,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): return RAGQueryResult(content=None) # sort by score - chunks, scores = zip(*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)) + chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) tokens = 0 picked = [] diff --git a/llama_stack/providers/inline/vector_io/faiss/config.py b/llama_stack/providers/inline/vector_io/faiss/config.py index ae859842d..9eae9ed67 100644 --- a/llama_stack/providers/inline/vector_io/faiss/config.py +++ b/llama_stack/providers/inline/vector_io/faiss/config.py @@ -6,13 +6,13 @@ from typing import Any, Dict -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel from llama_stack.providers.utils.kvstore.config import ( KVStoreConfig, SqliteKVStoreConfig, ) +from llama_stack.schema_utils import json_schema_type @json_schema_type diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index b52fb074c..410d8bd8b 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -103,7 +103,7 @@ class FaissIndex(EmbeddingIndex): chunks = [] scores = [] - for d, i in zip(distances[0], indices[0]): + for d, i in zip(distances[0], indices[0], strict=False): if i < 0: continue chunks.append(self.chunk_by_index[int(i)]) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index fcd7cd8f9..6c787bc29 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -80,7 +80,7 @@ class SQLiteVecIndex(EmbeddingIndex): try: # Start transaction cur.execute("BEGIN TRANSACTION") - for chunk, emb in zip(chunks, embeddings): + for chunk, emb in zip(chunks, embeddings, strict=False): # Serialize and insert the chunk metadata. 
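Context note (not part of the patch): the vector-store hunks above add an explicit `strict=False` to `zip()` calls. A small sketch of what that flag controls on Python 3.10+; the truncating behavior is unchanged, the flag only makes it explicit, which is what lint rules such as ruff's B905 ask for:

```python
chunks = ["chunk-a", "chunk-b", "chunk-c"]
scores = [0.2, 0.9, 0.5]

# strict=False keeps the historical behavior: iteration stops at the shorter
# input. strict=True would raise ValueError if the lengths ever diverge.
pairs = list(zip(chunks, scores, strict=False))

# The sort-by-score idiom from memory.py: order pairs by descending score,
# then unzip back into two parallel tuples.
chunks_sorted, scores_sorted = zip(*sorted(pairs, key=lambda x: x[1], reverse=True), strict=False)
assert list(scores_sorted) == [0.9, 0.5, 0.2]
```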
chunk_json = chunk.model_dump_json() cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,)) diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 4422baba5..88a65397a 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -42,7 +42,7 @@ def available_providers() -> List[ProviderSpec]: provider_type="inline::meta-reference", pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], module="llama_stack.providers.inline.vector_io.faiss", - config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig", + config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig", deprecation_warning="Please use the `inline::faiss` provider instead.", api_dependencies=[Api.inference], ), @@ -51,7 +51,7 @@ def available_providers() -> List[ProviderSpec]: provider_type="inline::faiss", pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], module="llama_stack.providers.inline.vector_io.faiss", - config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig", + config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig", api_dependencies=[Api.inference], ), InlineProviderSpec( @@ -68,7 +68,7 @@ def available_providers() -> List[ProviderSpec]: adapter_type="chromadb", pip_packages=EMBEDDING_DEPS + ["chromadb-client"], module="llama_stack.providers.remote.vector_io.chroma", - config_class="llama_stack.providers.remote.vector_io.chroma.ChromaRemoteImplConfig", + config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig", ), api_dependencies=[Api.inference], ), @@ -77,7 +77,7 @@ def available_providers() -> List[ProviderSpec]: provider_type="inline::chromadb", pip_packages=EMBEDDING_DEPS + ["chromadb"], module="llama_stack.providers.inline.vector_io.chroma", - config_class="llama_stack.providers.inline.vector_io.chroma.ChromaInlineImplConfig", + config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig", api_dependencies=[Api.inference], ), remote_provider_spec( @@ -86,7 +86,7 @@ def available_providers() -> List[ProviderSpec]: adapter_type="pgvector", pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"], module="llama_stack.providers.remote.vector_io.pgvector", - config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorConfig", + config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig", ), api_dependencies=[Api.inference], ), @@ -96,7 +96,7 @@ def available_providers() -> List[ProviderSpec]: adapter_type="weaviate", pip_packages=EMBEDDING_DEPS + ["weaviate-client"], module="llama_stack.providers.remote.vector_io.weaviate", - config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateConfig", + config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig", provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData", ), api_dependencies=[Api.inference], @@ -107,7 +107,7 @@ def available_providers() -> List[ProviderSpec]: adapter_type="sample", pip_packages=[], module="llama_stack.providers.remote.vector_io.sample", - config_class="llama_stack.providers.remote.vector_io.sample.SampleConfig", + config_class="llama_stack.providers.remote.vector_io.sample.SampleVectorIOConfig", ), api_dependencies=[], ), @@ -117,7 +117,7 @@ def available_providers() -> List[ProviderSpec]: adapter_type="qdrant", pip_packages=EMBEDDING_DEPS + ["qdrant-client"], 
module="llama_stack.providers.remote.vector_io.qdrant", - config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantConfig", + config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig", ), api_dependencies=[Api.inference], ), diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 917ac7a25..e896f0597 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -8,7 +8,6 @@ import json from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from botocore.client import BaseClient -from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer @@ -28,6 +27,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig from llama_stack.providers.utils.bedrock.client import create_bedrock_client from llama_stack.providers.utils.inference.model_registry import ( diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 2158fc5b4..1ce267e8d 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -7,9 +7,7 @@ from typing import AsyncGenerator, List, Optional, Union from cerebras.cloud.sdk import AsyncCerebras -from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import TopKSamplingStrategy from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.common.content_types import InterleavedContent @@ -28,6 +26,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.models.llama.datatypes import CoreModelId, TopKSamplingStrategy from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_model_alias, diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py index 6eb4dffec..81682c980 100644 --- a/llama_stack/providers/remote/inference/cerebras/config.py +++ b/llama_stack/providers/remote/inference/cerebras/config.py @@ -7,9 +7,10 @@ import os from typing import Any, Dict, Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field, SecretStr +from llama_stack.schema_utils import json_schema_type + DEFAULT_BASE_URL = "https://api.cerebras.ai" diff --git a/llama_stack/providers/remote/inference/databricks/config.py b/llama_stack/providers/remote/inference/databricks/config.py index ae2b056ea..6aaf7e594 100644 --- a/llama_stack/providers/remote/inference/databricks/config.py +++ b/llama_stack/providers/remote/inference/databricks/config.py @@ -5,9 +5,10 @@ # the root directory of this source tree. 
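Context note (not part of the patch): the registry hunks above only rename dotted `config_class` strings, but those strings must match real class names because they are resolved dynamically. A hypothetical resolver (not the actual llama_stack loader) to illustrate why a stale name such as `FaissImplConfig` would fail at load time:

```python
import importlib


def load_config_class(dotted_path: str) -> type:
    """Resolve a 'package.module.ClassName' string to the class object."""
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)  # AttributeError if the name is stale


# Example with one of the renamed entries from the registry above.
cls = load_config_class("llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig")
```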
-from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class DatabricksImplConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index d56be1465..3d306e61f 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -6,7 +6,6 @@ from typing import AsyncGenerator, List, Optional -from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI @@ -25,6 +24,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_model_alias, diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index aa4c2d1de..005dfe829 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -6,9 +6,10 @@ from typing import Any, Dict, Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field, SecretStr +from llama_stack.schema_utils import json_schema_type + @json_schema_type class FireworksImplConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 7e8f85313..acf37b248 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -7,7 +7,6 @@ from typing import AsyncGenerator, List, Optional, Union from fireworks.client import Fireworks -from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer @@ -30,6 +29,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_model_alias, diff --git a/llama_stack/providers/remote/inference/groq/config.py b/llama_stack/providers/remote/inference/groq/config.py index 7c5023410..cb2619437 100644 --- a/llama_stack/providers/remote/inference/groq/config.py +++ b/llama_stack/providers/remote/inference/groq/config.py @@ -6,9 +6,10 @@ from typing import Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class GroqConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index 59ec8b0d2..441b6af5c 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -9,9 +9,6 @@ from typing import AsyncIterator, List, Optional, Union import groq from groq import Groq -from llama_models.datatypes import SamplingParams -from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat -from 
llama_models.sku_list import CoreModelId from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -29,6 +26,8 @@ from llama_stack.apis.inference import ( ToolConfig, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat +from llama_stack.models.llama.sku_list import CoreModelId from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py index 2445c1b39..f1138e789 100644 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -24,7 +24,6 @@ from groq.types.chat.chat_completion_user_message_param import ( ) from groq.types.chat.completion_create_params import CompletionCreateParams from groq.types.shared.function_definition import FunctionDefinition -from llama_models.llama3.api.datatypes import ToolParamDefinition from llama_stack.apis.common.content_types import ( TextDelta, @@ -44,6 +43,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.models.llama.datatypes import ToolParamDefinition from llama_stack.providers.utils.inference.openai_compat import ( UnparseableToolCall, convert_tool_call, diff --git a/llama_stack/providers/remote/inference/nvidia/config.py b/llama_stack/providers/remote/inference/nvidia/config.py index 9bf5eb469..abd34b498 100644 --- a/llama_stack/providers/remote/inference/nvidia/config.py +++ b/llama_stack/providers/remote/inference/nvidia/config.py @@ -7,9 +7,10 @@ import os from typing import Any, Dict, Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field, SecretStr +from llama_stack.schema_utils import json_schema_type + @json_schema_type class NVIDIAConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 82343513f..8e67333af 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -4,12 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import logging import warnings from typing import AsyncIterator, List, Optional, Union -from llama_models.datatypes import SamplingParams -from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat -from llama_models.sku_list import CoreModelId from openai import APIConnectionError, AsyncOpenAI from llama_stack.apis.inference import ( @@ -28,6 +26,12 @@ from llama_stack.apis.inference import ( ToolChoice, ToolConfig, ) +from llama_stack.models.llama.datatypes import ( + CoreModelId, + SamplingParams, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_model_alias, @@ -45,6 +49,8 @@ from .openai_utils import ( ) from .utils import _is_nvidia_hosted, check_health +logger = logging.getLogger(__name__) + _MODEL_ALIASES = [ build_model_alias( "meta/llama3-8b-instruct", @@ -92,7 +98,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): # TODO(mf): filter by available models ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES) - print(f"Initializing NVIDIAInferenceAdapter({config.url})...") + logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...") if _is_nvidia_hosted(config): if not config.api_key: diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index c757c562c..9799eedcc 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -8,17 +8,6 @@ import json import warnings from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union -from llama_models.datatypes import ( - GreedySamplingStrategy, - TopKSamplingStrategy, - TopPSamplingStrategy, -) -from llama_models.llama3.api.datatypes import ( - BuiltinTool, - StopReason, - ToolCall, - ToolDefinition, -) from openai import AsyncStream from openai.types.chat import ( ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, @@ -87,6 +76,15 @@ from llama_stack.apis.inference import ( ToolResponseMessage, UserMessage, ) +from llama_stack.models.llama.datatypes import ( + BuiltinTool, + GreedySamplingStrategy, + StopReason, + ToolCall, + ToolDefinition, + TopKSamplingStrategy, + TopPSamplingStrategy, +) from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, ) diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py index 0ec80e9dd..7d3f3f27e 100644 --- a/llama_stack/providers/remote/inference/nvidia/utils.py +++ b/llama_stack/providers/remote/inference/nvidia/utils.py @@ -4,12 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import logging from typing import Tuple import httpx from . 
import NVIDIAConfig +logger = logging.getLogger(__name__) + def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: return "integrate.api.nvidia.com" in config.url @@ -42,7 +45,7 @@ async def check_health(config: NVIDIAConfig) -> None: RuntimeError: If the server is not running or ready """ if not _is_nvidia_hosted(config): - print("Checking NVIDIA NIM health...") + logger.info("Checking NVIDIA NIM health...") try: is_live, is_ready = await _get_health(config.url) if not is_live: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 1c12d0d91..f524c0734 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -8,7 +8,6 @@ import logging from typing import AsyncGenerator, List, Optional, Union import httpx -from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer from ollama import AsyncClient @@ -34,6 +33,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType +from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, diff --git a/llama_stack/providers/remote/inference/runpod/config.py b/llama_stack/providers/remote/inference/runpod/config.py index 1a9582052..e59cfe59b 100644 --- a/llama_stack/providers/remote/inference/runpod/config.py +++ b/llama_stack/providers/remote/inference/runpod/config.py @@ -6,9 +6,10 @@ from typing import Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class RunpodImplConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index a3c615418..1abb17336 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -6,11 +6,11 @@ from typing import AsyncGenerator from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.models.llama.datatypes import Message # from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper diff --git a/llama_stack/providers/remote/inference/sambanova/config.py b/llama_stack/providers/remote/inference/sambanova/config.py index 1798841df..a30c29b74 100644 --- a/llama_stack/providers/remote/inference/sambanova/config.py +++ b/llama_stack/providers/remote/inference/sambanova/config.py @@ -6,9 +6,10 @@ from typing import Any, Dict, Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class SambaNovaImplConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 3546ee977..b906e0dcb 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ 
b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -7,12 +7,6 @@ import json from typing import AsyncGenerator -from llama_models.datatypes import ( - CoreModelId, - GreedySamplingStrategy, - TopKSamplingStrategy, - TopPSamplingStrategy, -) from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI @@ -23,6 +17,12 @@ from llama_stack.apis.common.content_types import ( TextContentItem, ) from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.models.llama.datatypes import ( + CoreModelId, + GreedySamplingStrategy, + TopKSamplingStrategy, + TopPSamplingStrategy, +) from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_model_alias, diff --git a/llama_stack/providers/remote/inference/tgi/config.py b/llama_stack/providers/remote/inference/tgi/config.py index 4f690dec6..6ad663662 100644 --- a/llama_stack/providers/remote/inference/tgi/config.py +++ b/llama_stack/providers/remote/inference/tgi/config.py @@ -6,9 +6,10 @@ from typing import Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field, SecretStr +from llama_stack.schema_utils import json_schema_type + @json_schema_type class TGIImplConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 72eaa6c31..1909e01f8 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -11,7 +11,6 @@ from typing import AsyncGenerator, List, Optional from huggingface_hub import AsyncInferenceClient, HfApi from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import all_registered_models from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import ( @@ -31,6 +30,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, diff --git a/llama_stack/providers/remote/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py index a56cb5bb8..fda3b8f43 100644 --- a/llama_stack/providers/remote/inference/together/config.py +++ b/llama_stack/providers/remote/inference/together/config.py @@ -6,9 +6,10 @@ from typing import Any, Dict, Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field, SecretStr +from llama_stack.schema_utils import json_schema_type + @json_schema_type class TogetherImplConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 916e64ad4..054501da8 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -6,7 +6,6 @@ from typing import AsyncGenerator, List, Optional, Union -from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer from together import Together @@ -29,6 +28,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from 
llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_model_alias, diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py index a3a4c6930..c75cc8926 100644 --- a/llama_stack/providers/remote/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -6,9 +6,10 @@ from typing import Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class VLLMInferenceAdapterConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 8f9cf68a8..b22284302 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -7,10 +7,9 @@ import json import logging from typing import AsyncGenerator, List, Optional, Union -from llama_models.llama3.api import StopReason, ToolCall +from llama_models.datatypes import StopReason, ToolCall from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import all_registered_models from openai import OpenAI from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus @@ -37,6 +36,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, diff --git a/llama_stack/providers/remote/safety/bedrock/config.py b/llama_stack/providers/remote/safety/bedrock/config.py index 8c61decf3..1ca8d95cb 100644 --- a/llama_stack/providers/remote/safety/bedrock/config.py +++ b/llama_stack/providers/remote/safety/bedrock/config.py @@ -5,9 +5,8 @@ # the root directory of this source tree. 
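Context note (not part of the patch): earlier in this diff the NVIDIA adapter replaces bare `print()` calls with a module-level logger. A minimal sketch of that pattern; the call site below is hypothetical and only the logger setup mirrors the hunks:

```python
import logging

logger = logging.getLogger(__name__)


def initialize_adapter(url: str) -> None:
    # Routed through the host application's logging config instead of stdout.
    logger.info(f"Initializing NVIDIAInferenceAdapter({url})...")
```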
-from llama_models.schema_utils import json_schema_type - from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig +from llama_stack.schema_utils import json_schema_type @json_schema_type diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 564f76088..8ef9f5705 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -7,7 +7,6 @@ from typing import Any, Dict, List, Optional import requests -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -18,6 +17,7 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import BraveSearchToolConfig diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index f894a8e65..3bf3a7740 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -16,12 +16,13 @@ from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) -from .config import ChromaVectorIOConfig +from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig log = logging.getLogger(__name__) @@ -68,7 +69,7 @@ class ChromaIndex(EmbeddingIndex): chunks = [] scores = [] - for dist, doc in zip(distances, documents): + for dist, doc in zip(distances, documents, strict=False): try: doc = json.loads(doc) chunk = Chunk(**doc) @@ -88,7 +89,7 @@ class ChromaIndex(EmbeddingIndex): class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( self, - config: Union[ChromaVectorIOConfig, ChromaVectorIOConfig], + config: Union[RemoteChromaVectorIOConfig, InlineChromaVectorIOConfig], inference_api: Api.inference, ) -> None: log.info(f"Initializing ChromaVectorIOAdapter with url: {config}") @@ -99,7 +100,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): self.cache = {} async def initialize(self) -> None: - if isinstance(self.config, ChromaVectorIOConfig): + if isinstance(self.config, RemoteChromaVectorIOConfig): log.info(f"Connecting to Chroma server at: {self.config.url}") url = self.config.url.rstrip("/") parsed = urlparse(url) diff --git a/llama_stack/providers/remote/vector_io/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py index 2a64d7c67..7811de1ca 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/config.py +++ b/llama_stack/providers/remote/vector_io/pgvector/config.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
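Context note (not part of the patch): the Chroma adapter hunk above imports two identically named config classes under different aliases so the adapter can branch on which one it received. Restated as a standalone sketch; the union type and isinstance check follow the diff, while the function body is illustrative:

```python
from typing import Union

from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig as RemoteChromaVectorIOConfig


def make_client(config: Union[RemoteChromaVectorIOConfig, InlineChromaVectorIOConfig]):
    if isinstance(config, RemoteChromaVectorIOConfig):
        ...  # connect to a Chroma server at config.url
    else:
        ...  # use an in-process / local persistent Chroma instance
```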
-from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class PGVectorVectorIOConfig(BaseModel): diff --git a/llama_stack/providers/remote/vector_io/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py index 613cfa6e4..f212882d8 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/config.py +++ b/llama_stack/providers/remote/vector_io/qdrant/config.py @@ -6,9 +6,10 @@ from typing import Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel +from llama_stack.schema_utils import json_schema_type + @json_schema_type class QdrantVectorIOConfig(BaseModel): diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index e1091e2cf..586b8ca95 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -55,7 +55,7 @@ class QdrantIndex(EmbeddingIndex): ) points = [] - for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): + for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)): chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}" points.append( PointStruct( diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 45b276cc3..2e7bd537f 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -7,8 +7,6 @@ import os import pytest -from llama_models.datatypes import SamplingParams, TopPSamplingStrategy -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.agents import ( AgentConfig, @@ -25,6 +23,7 @@ from llama_stack.apis.agents import ( ) from llama_stack.apis.inference import CompletionMessage, UserMessage from llama_stack.apis.safety import ViolationLevel +from llama_stack.models.llama.datatypes import BuiltinTool, SamplingParams, TopPSamplingStrategy from llama_stack.providers.datatypes import Api # How to run this test: diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py index ec3d08728..ad80b8601 100644 --- a/llama_stack/providers/tests/eval/test_eval.py +++ b/llama_stack/providers/tests/eval/test_eval.py @@ -10,8 +10,8 @@ import pytest from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType from llama_stack.apis.eval.eval import ( - AppEvalTaskConfig, - BenchmarkEvalTaskConfig, + AppBenchmarkConfig, + BenchmarkBenchmarkConfig, ModelCandidate, ) from llama_stack.apis.inference import SamplingParams @@ -30,18 +30,18 @@ from .constants import JUDGE_PROMPT class Testeval: @pytest.mark.asyncio - async def test_eval_tasks_list(self, eval_stack): + async def test_benchmarks_list(self, eval_stack): # NOTE: this needs you to ensure that you are starting from a clean state # but so far we don't have an unregister API unfortunately, so be careful - eval_tasks_impl = eval_stack[Api.eval_tasks] - response = await eval_tasks_impl.list_eval_tasks() + benchmarks_impl = eval_stack[Api.benchmarks] + response = await benchmarks_impl.list_benchmarks() assert isinstance(response, list) @pytest.mark.asyncio async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model): - eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl, models_impl = ( + 
eval_impl, benchmarks_impl, datasetio_impl, datasets_impl, models_impl = ( eval_stack[Api.eval], - eval_stack[Api.eval_tasks], + eval_stack[Api.benchmarks], eval_stack[Api.datasetio], eval_stack[Api.datasets], eval_stack[Api.models], @@ -59,17 +59,17 @@ class Testeval: scoring_functions = [ "basic::equality", ] - task_id = "meta-reference::app_eval" - await eval_tasks_impl.register_eval_task( - eval_task_id=task_id, + benchmark_id = "meta-reference::app_eval" + await benchmarks_impl.register_benchmark( + benchmark_id=benchmark_id, dataset_id="test_dataset_for_eval", scoring_functions=scoring_functions, ) response = await eval_impl.evaluate_rows( - task_id=task_id, + benchmark_id=benchmark_id, input_rows=rows.rows, scoring_functions=scoring_functions, - task_config=AppEvalTaskConfig( + task_config=AppBenchmarkConfig( eval_candidate=ModelCandidate( model=inference_model, sampling_params=SamplingParams(), @@ -92,9 +92,9 @@ class Testeval: @pytest.mark.asyncio async def test_eval_run_eval(self, eval_stack, inference_model, judge_model): - eval_impl, eval_tasks_impl, datasets_impl, models_impl = ( + eval_impl, benchmarks_impl, datasets_impl, models_impl = ( eval_stack[Api.eval], - eval_stack[Api.eval_tasks], + eval_stack[Api.benchmarks], eval_stack[Api.datasets], eval_stack[Api.models], ) @@ -105,15 +105,15 @@ class Testeval: "basic::subset_of", ] - task_id = "meta-reference::app_eval-2" - await eval_tasks_impl.register_eval_task( - eval_task_id=task_id, + benchmark_id = "meta-reference::app_eval-2" + await benchmarks_impl.register_benchmark( + benchmark_id=benchmark_id, dataset_id="test_dataset_for_eval", scoring_functions=scoring_functions, ) response = await eval_impl.run_eval( - task_id=task_id, - task_config=AppEvalTaskConfig( + benchmark_id=benchmark_id, + task_config=AppBenchmarkConfig( eval_candidate=ModelCandidate( model=inference_model, sampling_params=SamplingParams(), @@ -121,9 +121,9 @@ class Testeval: ), ) assert response.job_id == "0" - job_status = await eval_impl.job_status(task_id, response.job_id) + job_status = await eval_impl.job_status(benchmark_id, response.job_id) assert job_status and job_status.value == "completed" - eval_response = await eval_impl.job_result(task_id, response.job_id) + eval_response = await eval_impl.job_result(benchmark_id, response.job_id) assert eval_response is not None assert len(eval_response.generations) == 5 @@ -131,9 +131,9 @@ class Testeval: @pytest.mark.asyncio async def test_eval_run_benchmark_eval(self, eval_stack, inference_model): - eval_impl, eval_tasks_impl, datasets_impl, models_impl = ( + eval_impl, benchmarks_impl, datasets_impl, models_impl = ( eval_stack[Api.eval], - eval_stack[Api.eval_tasks], + eval_stack[Api.benchmarks], eval_stack[Api.datasets], eval_stack[Api.models], ) @@ -159,20 +159,20 @@ class Testeval: ) # register eval task - await eval_tasks_impl.register_eval_task( - eval_task_id="meta-reference-mmlu", + await benchmarks_impl.register_benchmark( + benchmark_id="meta-reference-mmlu", dataset_id="mmlu", scoring_functions=["basic::regex_parser_multiple_choice_answer"], ) # list benchmarks - response = await eval_tasks_impl.list_eval_tasks() + response = await benchmarks_impl.list_benchmarks() assert len(response) > 0 benchmark_id = "meta-reference-mmlu" response = await eval_impl.run_eval( - task_id=benchmark_id, - task_config=BenchmarkEvalTaskConfig( + benchmark_id=benchmark_id, + task_config=BenchmarkBenchmarkConfig( eval_candidate=ModelCandidate( model=inference_model, sampling_params=SamplingParams(), diff --git 
a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py index 3eba991c1..34725e957 100644 --- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py +++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py @@ -23,8 +23,6 @@ from groq.types.chat.chat_completion_message_tool_call import ( Function, ) from groq.types.shared.function_definition import FunctionDefinition -from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy -from llama_models.llama3.api.datatypes import ToolParamDefinition from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ( @@ -38,6 +36,7 @@ from llama_stack.apis.inference import ( ToolDefinition, UserMessage, ) +from llama_stack.models.llama.datatypes import GreedySamplingStrategy, ToolParamDefinition, TopPSamplingStrategy from llama_stack.providers.remote.inference.groq.groq_utils import ( convert_chat_completion_request, convert_chat_completion_response, diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py index c087c5df2..323c6cb6a 100644 --- a/llama_stack/providers/tests/inference/test_prompt_adapter.py +++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py @@ -6,19 +6,18 @@ import unittest -from llama_models.llama3.api.datatypes import ( - BuiltinTool, - ToolDefinition, - ToolParamDefinition, - ToolPromptFormat, -) - from llama_stack.apis.inference import ( ChatCompletionRequest, SystemMessage, ToolConfig, UserMessage, ) +from llama_stack.models.llama.datatypes import ( + BuiltinTool, + ToolDefinition, + ToolParamDefinition, + ToolPromptFormat, +) from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, ) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 14ed2fc4b..f25b95004 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -6,14 +6,6 @@ import pytest -from llama_models.llama3.api.datatypes import ( - SamplingParams, - StopReason, - ToolCall, - ToolDefinition, - ToolParamDefinition, - ToolPromptFormat, -) from pydantic import BaseModel, ValidationError from llama_stack.apis.common.content_types import ToolCallParseStatus @@ -30,6 +22,14 @@ from llama_stack.apis.inference import ( UserMessage, ) from llama_stack.apis.models import ListModelsResponse, Model +from llama_stack.models.llama.datatypes import ( + SamplingParams, + StopReason, + ToolCall, + ToolDefinition, + ToolParamDefinition, + ToolPromptFormat, +) from .utils import group_chunks diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py index 2f96e66d4..4d7183c49 100644 --- a/llama_stack/providers/tests/inference/test_vision_inference.py +++ b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -88,7 +88,7 @@ class TestVisionModelInference: expected_strings_to_check = [ ["puppy"], ] - for image, expected_strings in zip(images, expected_strings_to_check): + for image, expected_strings in zip(images, expected_strings_to_check, strict=False): response = [ r async for r in await inference_impl.chat_completion( diff --git a/llama_stack/providers/tests/report.py b/llama_stack/providers/tests/report.py index 
3901dc2e3..febd13045 100644 --- a/llama_stack/providers/tests/report.py +++ b/llama_stack/providers/tests/report.py @@ -9,11 +9,12 @@ from collections import defaultdict from pathlib import Path import pytest -from llama_models.datatypes import CoreModelId -from llama_models.sku_list import all_registered_models from pytest import ExitCode from pytest_html.basereport import _process_outcome +from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_list import all_registered_models + INFERENCE_APIS = ["chat_completion"] FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"] SUPPORTED_MODELS = { diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 0ff632717..76343b7f4 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -10,8 +10,8 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel +from llama_stack.apis.benchmarks import BenchmarkInput from llama_stack.apis.datasets import DatasetInput -from llama_stack.apis.eval_tasks import EvalTaskInput from llama_stack.apis.models import ModelInput from llama_stack.apis.scoring_functions import ScoringFnInput from llama_stack.apis.shields import ShieldInput @@ -42,7 +42,7 @@ async def construct_stack_for_test( vector_dbs: Optional[List[VectorDBInput]] = None, datasets: Optional[List[DatasetInput]] = None, scoring_fns: Optional[List[ScoringFnInput]] = None, - eval_tasks: Optional[List[EvalTaskInput]] = None, + benchmarks: Optional[List[BenchmarkInput]] = None, tool_groups: Optional[List[ToolGroupInput]] = None, ) -> TestStack: sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") @@ -56,7 +56,7 @@ async def construct_stack_for_test( vector_dbs=vector_dbs or [], datasets=datasets or [], scoring_fns=scoring_fns or [], - eval_tasks=eval_tasks or [], + benchmarks=benchmarks or [], tool_groups=tool_groups or [], ) run_config = parse_and_maybe_upgrade_config(run_config) diff --git a/llama_stack/providers/tests/vector_io/conftest.py b/llama_stack/providers/tests/vector_io/conftest.py index 3da64ff2e..1f9799100 100644 --- a/llama_stack/providers/tests/vector_io/conftest.py +++ b/llama_stack/providers/tests/vector_io/conftest.py @@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ ), pytest.param( { - "inference": "bedrock", + "inference": "ollama", "vector_io": "qdrant", }, id="qdrant", diff --git a/llama_stack/providers/tests/vector_io/fixtures.py b/llama_stack/providers/tests/vector_io/fixtures.py index 30a2679d7..1797d47a5 100644 --- a/llama_stack/providers/tests/vector_io/fixtures.py +++ b/llama_stack/providers/tests/vector_io/fixtures.py @@ -17,6 +17,7 @@ from llama_stack.providers.inline.vector_io.faiss import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig +from llama_stack.providers.remote.vector_io.qdrant import QdrantVectorIOConfig from llama_stack.providers.remote.vector_io.weaviate import WeaviateVectorIOConfig from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig @@ -127,13 +128,26 @@ def vector_io_chroma() -> ProviderFixture: ) -VECTOR_IO_FIXTURES = [ - "faiss", - "pgvector", - "weaviate", - "chroma", - "sqlite_vec", -] +@pytest.fixture(scope="session") 
+def vector_io_qdrant() -> ProviderFixture: + url = os.getenv("QDRANT_URL") + if url: + config = QdrantVectorIOConfig(url=url) + provider_type = "remote::qdrant" + else: + raise ValueError("QDRANT_URL must be set") + return ProviderFixture( + providers=[ + Provider( + provider_id="qdrant", + provider_type=provider_type, + config=config.model_dump(), + ) + ] + ) + + +VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma", "qdrant", "sqlite_vec"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py index 64fe30f55..cab3725da 100644 --- a/llama_stack/providers/utils/inference/__init__.py +++ b/llama_stack/providers/utils/inference/__init__.py @@ -6,8 +6,8 @@ from typing import List -from llama_models.datatypes import * # noqa: F403 -from llama_models.sku_list import all_registered_models +from llama_stack.models.llama.datatypes import * # noqa: F403 +from llama_stack.models.llama.sku_list import all_registered_models def is_supported_safety_model(model: Model) -> bool: diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 9345da949..c5f6cd6b5 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -7,9 +7,8 @@ from collections import namedtuple from typing import List, Optional -from llama_models.sku_list import all_registered_models - from llama_stack.apis.models.models import ModelType +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference import ( ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 00e291e8f..def7e8f37 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -7,14 +7,7 @@ import json import logging from typing import AsyncGenerator, Dict, List, Optional, Union -from llama_models.datatypes import ( - GreedySamplingStrategy, - SamplingParams, - TopKSamplingStrategy, - TopPSamplingStrategy, -) from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import StopReason, ToolCall from openai.types.chat import ChatCompletionMessageToolCall from pydantic import BaseModel @@ -37,6 +30,14 @@ from llama_stack.apis.inference import ( Message, TokenLogProbs, ) +from llama_stack.models.llama.datatypes import ( + GreedySamplingStrategy, + SamplingParams, + StopReason, + ToolCall, + TopKSamplingStrategy, + TopPSamplingStrategy, +) from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, ) @@ -132,7 +133,7 @@ def convert_openai_completion_logprobs( if logprobs.tokens and logprobs.token_logprobs: return [ TokenLogProbs(logprobs_by_token={token: token_lp}) - for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs) + for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False) ] return None @@ -426,10 +427,14 @@ def convert_tool_call( """ Convert a ChatCompletionMessageToolCall tool call to either a ToolCall or UnparseableToolCall. Returns an UnparseableToolCall - if the tool call is not valid JSON. + if the tool call is not valid ToolCall. 
""" try: - arguments = json.loads(tool_call.function.arguments) + valid_tool_call = ToolCall( + call_id=tool_call.id, + tool_name=tool_call.function.name, + arguments=json.loads(tool_call.function.arguments), + ) except Exception as e: return UnparseableToolCall( call_id=tool_call.id or "", @@ -437,8 +442,4 @@ def convert_tool_call( arguments=tool_call.function.arguments or "", ) - return ToolCall( - call_id=tool_call.id, - tool_name=tool_call.function.name, - arguments=arguments, - ) + return valid_tool_call diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 15149e059..2782c661f 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -13,25 +13,7 @@ import re from typing import List, Optional, Tuple, Union import httpx -from llama_models.datatypes import ModelFamily, is_multimodal from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import ( - RawContent, - RawContentItem, - RawMediaItem, - RawMessage, - RawTextItem, - Role, - ToolPromptFormat, -) -from llama_models.llama3.prompt_templates import ( - BuiltinToolGenerator, - FunctionTagCustomToolGenerator, - JsonCustomToolGenerator, - PythonListCustomToolGenerator, - SystemDefaultGenerator, -) -from llama_models.sku_list import resolve_model from PIL import Image as PIL_Image from llama_stack.apis.common.content_types import ( @@ -49,8 +31,28 @@ from llama_stack.apis.inference import ( SystemMessage, SystemMessageBehavior, ToolChoice, + ToolDefinition, UserMessage, ) +from llama_stack.models.llama.datatypes import ( + ModelFamily, + RawContent, + RawContentItem, + RawMediaItem, + RawMessage, + RawTextItem, + Role, + ToolPromptFormat, + is_multimodal, +) +from llama_stack.models.llama.llama3.prompt_templates import ( + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + PythonListCustomToolGenerator, + SystemDefaultGenerator, +) +from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.utils.inference import supported_inference_models log = logging.getLogger(__name__) @@ -310,8 +312,6 @@ def response_format_prompt(fmt: Optional[ResponseFormat]): def augment_messages_for_tools_llama_3_1( request: ChatCompletionRequest, ) -> List[Message]: - assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" - existing_messages = request.messages existing_system_message = None if existing_messages[0].role == Role.system.value: @@ -351,6 +351,10 @@ def augment_messages_for_tools_llama_3_1( elif isinstance(existing_system_message.content, list): sys_content += "\n".join([_process(c) for c in existing_system_message.content]) + tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) + if tool_choice_prompt: + sys_content += "\n" + tool_choice_prompt + messages.append(SystemMessage(content=sys_content)) has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools) @@ -376,8 +380,6 @@ def augment_messages_for_tools_llama_3_1( def augment_messages_for_tools_llama_3_2( request: ChatCompletionRequest, ) -> List[Message]: - assert request.tool_config.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" - existing_messages = request.messages existing_system_message = None if existing_messages[0].role == Role.system.value: @@ -385,7 +387,6 @@ def augment_messages_for_tools_llama_3_2( assert 
existing_messages[0].role != Role.system.value, "Should only have 1 system message" - messages = [] sys_content = "" custom_tools, builtin_tools = [], [] for t in request.tools: @@ -394,7 +395,6 @@ def augment_messages_for_tools_llama_3_2( else: builtin_tools.append(t) - tool_template = None if builtin_tools: tool_gen = BuiltinToolGenerator() tool_template = tool_gen.gen(builtin_tools) @@ -422,8 +422,22 @@ def augment_messages_for_tools_llama_3_2( ): sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n") - messages.append(SystemMessage(content=sys_content.strip("\n"))) + tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) + if tool_choice_prompt: + sys_content += "\n" + tool_choice_prompt - # Add back existing messages from the request - messages += existing_messages + messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages] return messages + + +def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str: + if tool_choice == ToolChoice.auto: + return "" + elif tool_choice == ToolChoice.required: + return "You MUST use one of the provided functions/tools to answer the user query." + elif tool_choice == ToolChoice.none: + # tools are already not passed in + return "" + else: + # specific tool + return f"You MUST use the tool `{tool_choice}` to answer the user query." diff --git a/llama_stack/providers/utils/kvstore/sqlite/config.py b/llama_stack/providers/utils/kvstore/sqlite/config.py index a616c90d0..6a8b0a7cf 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/config.py +++ b/llama_stack/providers/utils/kvstore/sqlite/config.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class SqliteControlPlaneConfig(BaseModel): diff --git a/llama_stack/providers/utils/telemetry/trace_protocol.py b/llama_stack/providers/utils/telemetry/trace_protocol.py index 80c58a2c7..924274c42 100644 --- a/llama_stack/providers/utils/telemetry/trace_protocol.py +++ b/llama_stack/providers/utils/telemetry/trace_protocol.py @@ -9,9 +9,10 @@ import inspect from functools import wraps from typing import Any, AsyncGenerator, Callable, Type, TypeVar -from llama_models.llama3.api.datatypes import Primitive from pydantic import BaseModel +from llama_stack.models.llama.datatypes import Primitive + T = TypeVar("T") diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py new file mode 100644 index 000000000..56b9e5e4c --- /dev/null +++ b/llama_stack/schema_utils.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
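Context note (not part of the patch): the prompt_adapter hunk above drops the `ToolChoice.auto`-only assertion and instead appends a tool-choice instruction to the system prompt. A simplified standalone restatement of that mapping, using plain strings and omitting the `tools` argument; the string values are assumed to match the `ToolChoice` enum ("auto", "required", "none") or a specific tool name:

```python
def tool_choice_prompt(tool_choice: str) -> str:
    # Mirrors _get_tool_choice_prompt from the diff above, simplified.
    if tool_choice in ("auto", "none"):
        return ""  # nothing to add; for "none" the tools are not passed in at all
    if tool_choice == "required":
        return "You MUST use one of the provided functions/tools to answer the user query."
    return f"You MUST use the tool `{tool_choice}` to answer the user query."


sys_content = "You are a helpful assistant."
suffix = tool_choice_prompt("required")
if suffix:
    sys_content += "\n" + suffix  # appended as in augment_messages_for_tools_*
```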
+ +from dataclasses import dataclass +from typing import Any, Callable, List, Optional, TypeVar + +from .strong_typing.schema import json_schema_type, register_schema # noqa: F401 + +T = TypeVar("T") + + +@dataclass +class WebMethod: + route: Optional[str] = None + public: bool = False + request_examples: Optional[List[Any]] = None + response_examples: Optional[List[Any]] = None + method: Optional[str] = None + + +def webmethod( + route: Optional[str] = None, + method: Optional[str] = None, + public: Optional[bool] = False, + request_examples: Optional[List[Any]] = None, + response_examples: Optional[List[Any]] = None, +) -> Callable[[T], T]: + """ + Decorator that supplies additional metadata to an endpoint operation function. + + :param route: The URL path pattern associated with this operation which path parameters are substituted into. + :param public: True if the operation can be invoked without prior authentication. + :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON. + :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON. + """ + + def wrap(cls: T) -> T: + cls.__webmethod__ = WebMethod( + route=route, + method=method, + public=public or False, + request_examples=request_examples, + response_examples=response_examples, + ) + return cls + + return wrap diff --git a/llama_stack/scripts/generate_prompt_format.py b/llama_stack/scripts/generate_prompt_format.py new file mode 100644 index 000000000..ecdde900f --- /dev/null +++ b/llama_stack/scripts/generate_prompt_format.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
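Context note (not part of the patch): the new `llama_stack/schema_utils.py` above re-exports the strong_typing helpers and defines the `webmethod` decorator in-tree. A hypothetical usage sketch; the protocol name and route below are made up for illustration, only the decorator comes from the new module:

```python
from typing import Protocol

from llama_stack.schema_utils import webmethod


class ExampleAPI(Protocol):
    @webmethod(route="/v1/example/{item_id}", method="GET")
    async def get_item(self, item_id: str) -> dict: ...


# The decorator only attaches metadata; generators can read it back later.
assert ExampleAPI.get_item.__webmethod__.route == "/v1/example/{item_id}"
```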
+ +import importlib +from pathlib import Path +from typing import Optional + +import fire + +# from llama_stack.models.llama.datatypes import * # noqa: F403 +from llama_models.llama3.reference_impl.generation import Llama + +THIS_DIR = Path(__file__).parent.resolve() + + +def run_main( + ckpt_dir: str, + module_name: str, + output_path: str, + model_parallel_size: Optional[int] = None, +): + module = importlib.import_module(module_name) + assert hasattr(module, "usecases"), f"Module {module_name} missing usecases function" + tokenizer_path = str(THIS_DIR.parent / "llama3/api/tokenizer.model") + generator = Llama.build( + ckpt_dir=ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=512, + max_batch_size=1, + model_parallel_size=model_parallel_size, + ) + + use_cases = module.usecases() + text = "" + for u in use_cases: + if isinstance(u, str): + use_case_text = f"\n{u}\n" + else: + use_case_text = u.to_text(generator) + + text += use_case_text + print(use_case_text) + + text += "Thank You!\n" + + with open(output_path, "w") as f: + f.write(text) + + +def main(): + fire.Fire(run_main) + + +if __name__ == "__main__": + main() diff --git a/docs/openapi_generator/strong_typing/__init__.py b/llama_stack/strong_typing/__init__.py similarity index 100% rename from docs/openapi_generator/strong_typing/__init__.py rename to llama_stack/strong_typing/__init__.py diff --git a/docs/openapi_generator/strong_typing/auxiliary.py b/llama_stack/strong_typing/auxiliary.py similarity index 93% rename from docs/openapi_generator/strong_typing/auxiliary.py rename to llama_stack/strong_typing/auxiliary.py index bfaec0d29..cf19d6083 100644 --- a/docs/openapi_generator/strong_typing/auxiliary.py +++ b/llama_stack/strong_typing/auxiliary.py @@ -13,7 +13,7 @@ Type-safe data interchange for Python data classes. import dataclasses import sys from dataclasses import is_dataclass -from typing import Callable, Dict, Optional, overload, Type, TypeVar, Union +from typing import Callable, Dict, Optional, Type, TypeVar, Union, overload if sys.version_info >= (3, 9): from typing import Annotated as Annotated @@ -42,9 +42,7 @@ def _compact_dataclass_repr(obj: object) -> str: """ if is_dataclass(obj): - arglist = ", ".join( - repr(getattr(obj, field.name)) for field in dataclasses.fields(obj) - ) + arglist = ", ".join(repr(getattr(obj, field.name)) for field in dataclasses.fields(obj)) return f"{obj.__class__.__name__}({arglist})" else: return obj.__class__.__name__ @@ -62,9 +60,7 @@ def typeannotation(cls: Type[T], /) -> Type[T]: ... @overload -def typeannotation( - cls: None, *, eq: bool = True, order: bool = False -) -> Callable[[Type[T]], Type[T]]: ... +def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[Type[T]], Type[T]]: ... 
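[Reviewer note] The new generate_prompt_format.py script above is a fire CLI; a hedged invocation sketch follows. The checkpoint directory and module name are placeholders, and running it for real loads model weights via Llama.build, so it needs the checkpoint files and suitable GPU(s).

# Invoked through fire, e.g. (all paths/names are placeholders):
#   python llama_stack/scripts/generate_prompt_format.py \
#       ~/.llama/checkpoints/Llama3.2-3B-Instruct \
#       llama_stack.models.llama.llama3_2.prompts \
#       /tmp/prompt_format.md
#
# Equivalent direct call from Python (same placeholder arguments):
from llama_stack.scripts.generate_prompt_format import run_main

run_main(
    ckpt_dir="~/.llama/checkpoints/Llama3.2-3B-Instruct",      # placeholder checkpoint dir
    module_name="llama_stack.models.llama.llama3_2.prompts",   # placeholder module exposing `usecases()`
    output_path="/tmp/prompt_format.md",
)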
@dataclass_transform(eq_default=True, order_default=False) @@ -81,7 +77,7 @@ def typeannotation( """ def wrap(cls: Type[T]) -> Type[T]: - setattr(cls, "__repr__", _compact_dataclass_repr) + cls.__repr__ = _compact_dataclass_repr if not dataclasses.is_dataclass(cls): cls = dataclasses.dataclass( # type: ignore[call-overload] cls, diff --git a/docs/openapi_generator/strong_typing/classdef.py b/llama_stack/strong_typing/classdef.py similarity index 91% rename from docs/openapi_generator/strong_typing/classdef.py rename to llama_stack/strong_typing/classdef.py index b86940420..5ead886d4 100644 --- a/docs/openapi_generator/strong_typing/classdef.py +++ b/llama_stack/strong_typing/classdef.py @@ -22,13 +22,13 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Uni from .auxiliary import ( Alias, Annotated, + MaxLength, + Precision, float32, float64, int16, int32, int64, - MaxLength, - Precision, ) from .core import JsonType, Schema from .docstring import Docstring, DocstringParam @@ -181,17 +181,13 @@ def enum_values_to_type( # assign the newly created type to the same module where the defining class is enum_class.__module__ = module.__name__ - enum_class.__doc__ = str( - Docstring(short_description=title, long_description=description) - ) + enum_class.__doc__ = str(Docstring(short_description=title, long_description=description)) setattr(module, name, enum_class) return enum.unique(enum_class) -def schema_to_type( - schema: Schema, *, module: types.ModuleType, class_name: str -) -> TypeLike: +def schema_to_type(schema: Schema, *, module: types.ModuleType, class_name: str) -> TypeLike: """ Creates a Python type from a JSON schema. @@ -200,16 +196,14 @@ def schema_to_type( :param class_name: The name assigned to the top-level class. 
""" - top_node = typing.cast( - JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema) - ) + top_node = typing.cast(JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema)) if top_node.definitions is not None: for type_name, type_node in top_node.definitions.items(): type_def = node_to_typedef(module, type_name, type_node) if type_def.default is not dataclasses.MISSING: raise TypeError("disallowed: `default` for top-level type definitions") - setattr(type_def.type, "__module__", module.__name__) + type_def.type.__module__ = module.__name__ setattr(module, type_name, type_def.type) return node_to_typedef(module, class_name, top_node).type @@ -228,9 +222,7 @@ def json_to_value(target_type: TypeLike, data: JsonType) -> Any: return dataclasses.MISSING -def node_to_typedef( - module: types.ModuleType, context: str, node: JsonSchemaNode -) -> TypeDef: +def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode) -> TypeDef: if isinstance(node, JsonSchemaRef): match_obj = re.match(r"^#/definitions/(\w+)$", node.ref) if not match_obj: @@ -360,22 +352,16 @@ def node_to_typedef( prop_type = type_def.type else: prop_type = Union[(None, type_def.type)] - fields.append( - (prop_name, prop_type, dataclasses.field(default=type_def.default)) - ) + fields.append((prop_name, prop_type, dataclasses.field(default=type_def.default))) prop_desc = prop_node.title or prop_node.description if prop_desc is not None: params[prop_name] = DocstringParam(prop_name, prop_desc) fields.sort(key=lambda t: t[2].default is not dataclasses.MISSING) if sys.version_info >= (3, 12): - class_type = dataclasses.make_dataclass( - class_name, fields, module=module.__name__ - ) + class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__) else: - class_type = dataclasses.make_dataclass( - class_name, fields, namespace={"__module__": module.__name__} - ) + class_type = dataclasses.make_dataclass(class_name, fields, namespace={"__module__": module.__name__}) class_type.__doc__ = str( Docstring( short_description=node.title, @@ -402,12 +388,8 @@ class SchemaFlatteningOptions: recursive: bool = False -def flatten_schema( - schema: Schema, *, options: Optional[SchemaFlatteningOptions] = None -) -> Schema: - top_node = typing.cast( - JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema) - ) +def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions] = None) -> Schema: + top_node = typing.cast(JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema)) flattener = SchemaFlattener(options) obj = flattener.flatten(top_node) return typing.cast(Schema, object_to_json(obj)) @@ -442,9 +424,7 @@ class SchemaFlattener: obj = prop if obj.properties is not None: if self.options.qualified_names: - target_props.update( - (f"{name}.{n}", p) for n, p in obj.properties.items() - ) + target_props.update((f"{name}.{n}", p) for n, p in obj.properties.items()) else: target_props.update(obj.properties.items()) if obj.required is not None: diff --git a/docs/openapi_generator/strong_typing/core.py b/llama_stack/strong_typing/core.py similarity index 100% rename from docs/openapi_generator/strong_typing/core.py rename to llama_stack/strong_typing/core.py diff --git a/docs/openapi_generator/strong_typing/deserializer.py b/llama_stack/strong_typing/deserializer.py similarity index 84% rename from docs/openapi_generator/strong_typing/deserializer.py rename to llama_stack/strong_typing/deserializer.py index 
5859d3bbe..fc0f40f83 100644 --- a/docs/openapi_generator/strong_typing/deserializer.py +++ b/llama_stack/strong_typing/deserializer.py @@ -40,6 +40,7 @@ from typing import ( from .core import JsonType from .exception import JsonKeyError, JsonTypeError, JsonValueError from .inspection import ( + TypeLike, create_object, enum_value_types, evaluate_type, @@ -52,7 +53,6 @@ from .inspection import ( is_type_annotated, is_type_literal, is_type_optional, - TypeLike, unwrap_annotated_type, unwrap_literal_values, unwrap_optional_type, @@ -92,9 +92,7 @@ class NoneDeserializer(Deserializer[None]): def parse(self, data: JsonType) -> None: if data is not None: - raise JsonTypeError( - f"`None` type expects JSON `null` but instead received: {data}" - ) + raise JsonTypeError(f"`None` type expects JSON `null` but instead received: {data}") return None @@ -103,9 +101,7 @@ class BoolDeserializer(Deserializer[bool]): def parse(self, data: JsonType) -> bool: if not isinstance(data, bool): - raise JsonTypeError( - f"`bool` type expects JSON `boolean` data but instead received: {data}" - ) + raise JsonTypeError(f"`bool` type expects JSON `boolean` data but instead received: {data}") return bool(data) @@ -114,9 +110,7 @@ class IntDeserializer(Deserializer[int]): def parse(self, data: JsonType) -> int: if not isinstance(data, int): - raise JsonTypeError( - f"`int` type expects integer data as JSON `number` but instead received: {data}" - ) + raise JsonTypeError(f"`int` type expects integer data as JSON `number` but instead received: {data}") return int(data) @@ -125,9 +119,7 @@ class FloatDeserializer(Deserializer[float]): def parse(self, data: JsonType) -> float: if not isinstance(data, float) and not isinstance(data, int): - raise JsonTypeError( - f"`int` type expects data as JSON `number` but instead received: {data}" - ) + raise JsonTypeError(f"`int` type expects data as JSON `number` but instead received: {data}") return float(data) @@ -136,9 +128,7 @@ class StringDeserializer(Deserializer[str]): def parse(self, data: JsonType) -> str: if not isinstance(data, str): - raise JsonTypeError( - f"`str` type expects JSON `string` data but instead received: {data}" - ) + raise JsonTypeError(f"`str` type expects JSON `string` data but instead received: {data}") return str(data) @@ -147,9 +137,7 @@ class BytesDeserializer(Deserializer[bytes]): def parse(self, data: JsonType) -> bytes: if not isinstance(data, str): - raise JsonTypeError( - f"`bytes` type expects JSON `string` data but instead received: {data}" - ) + raise JsonTypeError(f"`bytes` type expects JSON `string` data but instead received: {data}") return base64.b64decode(data, validate=True) @@ -158,17 +146,13 @@ class DateTimeDeserializer(Deserializer[datetime.datetime]): def parse(self, data: JsonType) -> datetime.datetime: if not isinstance(data, str): - raise JsonTypeError( - f"`datetime` type expects JSON `string` data but instead received: {data}" - ) + raise JsonTypeError(f"`datetime` type expects JSON `string` data but instead received: {data}") if data.endswith("Z"): data = f"{data[:-1]}+00:00" # Python's isoformat() does not support military time zones like "Zulu" for UTC timestamp = datetime.datetime.fromisoformat(data) if timestamp.tzinfo is None: - raise JsonValueError( - f"timestamp lacks explicit time zone designator: {data}" - ) + raise JsonValueError(f"timestamp lacks explicit time zone designator: {data}") return timestamp @@ -177,9 +161,7 @@ class DateDeserializer(Deserializer[datetime.date]): def parse(self, data: JsonType) -> 
datetime.date: if not isinstance(data, str): - raise JsonTypeError( - f"`date` type expects JSON `string` data but instead received: {data}" - ) + raise JsonTypeError(f"`date` type expects JSON `string` data but instead received: {data}") return datetime.date.fromisoformat(data) @@ -189,9 +171,7 @@ class TimeDeserializer(Deserializer[datetime.time]): def parse(self, data: JsonType) -> datetime.time: if not isinstance(data, str): - raise JsonTypeError( - f"`time` type expects JSON `string` data but instead received: {data}" - ) + raise JsonTypeError(f"`time` type expects JSON `string` data but instead received: {data}") return datetime.time.fromisoformat(data) @@ -201,9 +181,7 @@ class UUIDDeserializer(Deserializer[uuid.UUID]): def parse(self, data: JsonType) -> uuid.UUID: if not isinstance(data, str): - raise JsonTypeError( - f"`UUID` type expects JSON `string` data but instead received: {data}" - ) + raise JsonTypeError(f"`UUID` type expects JSON `string` data but instead received: {data}") return uuid.UUID(data) @@ -212,9 +190,7 @@ class IPv4Deserializer(Deserializer[ipaddress.IPv4Address]): def parse(self, data: JsonType) -> ipaddress.IPv4Address: if not isinstance(data, str): - raise JsonTypeError( - f"`IPv4Address` type expects JSON `string` data but instead received: {data}" - ) + raise JsonTypeError(f"`IPv4Address` type expects JSON `string` data but instead received: {data}") return ipaddress.IPv4Address(data) @@ -223,9 +199,7 @@ class IPv6Deserializer(Deserializer[ipaddress.IPv6Address]): def parse(self, data: JsonType) -> ipaddress.IPv6Address: if not isinstance(data, str): - raise JsonTypeError( - f"`IPv6Address` type expects JSON `string` data but instead received: {data}" - ) + raise JsonTypeError(f"`IPv6Address` type expects JSON `string` data but instead received: {data}") return ipaddress.IPv6Address(data) @@ -244,9 +218,7 @@ class ListDeserializer(Deserializer[List[T]]): def parse(self, data: JsonType) -> List[T]: if not isinstance(data, list): type_name = python_type_to_str(self.item_type) - raise JsonTypeError( - f"type `List[{type_name}]` expects JSON `array` data but instead received: {data}" - ) + raise JsonTypeError(f"type `List[{type_name}]` expects JSON `array` data but instead received: {data}") return [self.item_parser.parse(item) for item in data] @@ -319,9 +291,7 @@ class SetDeserializer(Deserializer[Set[T]]): def parse(self, data: JsonType) -> Set[T]: if not isinstance(data, list): type_name = python_type_to_str(self.member_type) - raise JsonTypeError( - f"type `Set[{type_name}]` expects JSON `array` data but instead received: {data}" - ) + raise JsonTypeError(f"type `Set[{type_name}]` expects JSON `array` data but instead received: {data}") return set(self.member_parser.parse(item) for item in data) @@ -336,15 +306,11 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]): self.item_types = item_types def build(self, context: Optional[ModuleType]) -> None: - self.item_parsers = tuple( - _get_deserializer(item_type, context) for item_type in self.item_types - ) + self.item_parsers = tuple(_get_deserializer(item_type, context) for item_type in self.item_types) @property def container_type(self) -> str: - type_names = ", ".join( - python_type_to_str(item_type) for item_type in self.item_types - ) + type_names = ", ".join(python_type_to_str(item_type) for item_type in self.item_types) return f"Tuple[{type_names}]" def parse(self, data: JsonType) -> Tuple[Any, ...]: @@ -359,10 +325,7 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]): f"type 
`{self.container_type}` expects a JSON `array` of length {count} but received length {len(data)}" ) - return tuple( - item_parser.parse(item) - for item_parser, item in zip(self.item_parsers, data) - ) + return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data, strict=False)) class UnionDeserializer(Deserializer): @@ -375,9 +338,7 @@ class UnionDeserializer(Deserializer): self.member_types = member_types def build(self, context: Optional[ModuleType]) -> None: - self.member_parsers = tuple( - _get_deserializer(member_type, context) for member_type in self.member_types - ) + self.member_parsers = tuple(_get_deserializer(member_type, context) for member_type in self.member_types) def parse(self, data: JsonType) -> Any: for member_parser in self.member_parsers: @@ -389,21 +350,15 @@ class UnionDeserializer(Deserializer): # i.e. we don't have the type that we are looking for continue - type_names = ", ".join( - python_type_to_str(member_type) for member_type in self.member_types - ) - raise JsonKeyError( - f"type `Union[{type_names}]` could not be instantiated from: {data}" - ) + type_names = ", ".join(python_type_to_str(member_type) for member_type in self.member_types) + raise JsonKeyError(f"type `Union[{type_names}]` could not be instantiated from: {data}") def get_literal_properties(typ: type) -> Set[str]: "Returns the names of all properties in a class that are of a literal type." return set( - property_name - for property_name, property_type in get_class_properties(typ) - if is_type_literal(property_type) + property_name for property_name, property_type in get_class_properties(typ) if is_type_literal(property_type) ) @@ -450,9 +405,7 @@ class TaggedUnionDeserializer(Deserializer): @property def union_type(self) -> str: - type_names = ", ".join( - python_type_to_str(member_type) for member_type in self.member_types - ) + type_names = ", ".join(python_type_to_str(member_type) for member_type in self.member_types) return f"Union[{type_names}]" def parse(self, data: JsonType) -> Any: @@ -466,9 +419,7 @@ class TaggedUnionDeserializer(Deserializer): if disambiguating_value is None: continue - member_parser = self.member_parsers.get( - (property_name, disambiguating_value) - ) + member_parser = self.member_parsers.get((property_name, disambiguating_value)) if member_parser is None: raise JsonTypeError( f"disambiguating property value is invalid for tagged union type `{self.union_type}`: {data}" @@ -506,9 +457,7 @@ class LiteralDeserializer(Deserializer): value = self.parser.parse(data) if value not in self.values: value_names = ", ".join(repr(value) for value in self.values) - raise JsonTypeError( - f"type `Literal[{value_names}]` could not be instantiated from: {data}" - ) + raise JsonTypeError(f"type `Literal[{value_names}]` could not be instantiated from: {data}") return value @@ -549,9 +498,7 @@ class FieldDeserializer(abc.ABC, Generic[T, R]): field_name: str parser: Deserializer[T] - def __init__( - self, property_name: str, field_name: str, parser: Deserializer[T] - ) -> None: + def __init__(self, property_name: str, field_name: str, parser: Deserializer[T]) -> None: self.property_name = property_name self.field_name = field_name self.parser = parser @@ -565,9 +512,7 @@ class RequiredFieldDeserializer(FieldDeserializer[T, T]): def parse_field(self, data: Dict[str, JsonType]) -> T: if self.property_name not in data: - raise JsonKeyError( - f"missing required property `{self.property_name}` from JSON object: {data}" - ) + raise JsonKeyError(f"missing required 
property `{self.property_name}` from JSON object: {data}") return self.parser.parse(data[self.property_name]) @@ -641,32 +586,22 @@ class ClassDeserializer(Deserializer[T]): def assign(self, property_parsers: List[FieldDeserializer]) -> None: self.property_parsers = property_parsers - self.property_fields = set( - property_parser.property_name for property_parser in property_parsers - ) + self.property_fields = set(property_parser.property_name for property_parser in property_parsers) def parse(self, data: JsonType) -> T: if not isinstance(data, dict): type_name = python_type_to_str(self.class_type) - raise JsonTypeError( - f"`type `{type_name}` expects JSON `object` data but instead received: {data}" - ) + raise JsonTypeError(f"`type `{type_name}` expects JSON `object` data but instead received: {data}") object_data: Dict[str, JsonType] = typing.cast(Dict[str, JsonType], data) field_values = {} for property_parser in self.property_parsers: - field_values[property_parser.field_name] = property_parser.parse_field( - object_data - ) + field_values[property_parser.field_name] = property_parser.parse_field(object_data) if not self.property_fields.issuperset(object_data): - unassigned_names = [ - name for name in object_data if name not in self.property_fields - ] - raise JsonKeyError( - f"unrecognized fields in JSON object: {unassigned_names}" - ) + unassigned_names = [name for name in object_data if name not in self.property_fields] + raise JsonKeyError(f"unrecognized fields in JSON object: {unassigned_names}") return self.create(**field_values) @@ -686,9 +621,7 @@ class NamedTupleDeserializer(ClassDeserializer[NamedTuple]): def build(self, context: Optional[ModuleType]) -> None: property_parsers: List[FieldDeserializer] = [ - RequiredFieldDeserializer( - field_name, field_name, _get_deserializer(field_type, context) - ) + RequiredFieldDeserializer(field_name, field_name, _get_deserializer(field_type, context)) for field_name, field_type in get_resolved_hints(self.class_type).items() ] super().assign(property_parsers) @@ -729,17 +662,11 @@ class DataclassDeserializer(ClassDeserializer[T]): ) elif has_default_factory: default_factory = typing.cast(Callable[[], Any], field.default_factory) - field_parser = DefaultFactoryFieldDeserializer( - property_name, field.name, parser, default_factory - ) + field_parser = DefaultFactoryFieldDeserializer(property_name, field.name, parser, default_factory) elif is_optional: - field_parser = OptionalFieldDeserializer( - property_name, field.name, parser - ) + field_parser = OptionalFieldDeserializer(property_name, field.name, parser) else: - field_parser = RequiredFieldDeserializer( - property_name, field.name, parser - ) + field_parser = RequiredFieldDeserializer(property_name, field.name, parser) property_parsers.append(field_parser) @@ -778,22 +705,16 @@ class TypedClassDeserializer(ClassDeserializer[T]): parser = _get_deserializer(required_type, context) if is_optional: - field_parser: FieldDeserializer = OptionalFieldDeserializer( - property_name, field_name, parser - ) + field_parser: FieldDeserializer = OptionalFieldDeserializer(property_name, field_name, parser) else: - field_parser = RequiredFieldDeserializer( - property_name, field_name, parser - ) + field_parser = RequiredFieldDeserializer(property_name, field_name, parser) property_parsers.append(field_parser) super().assign(property_parsers) -def create_deserializer( - typ: TypeLike, context: Optional[ModuleType] = None -) -> Deserializer: +def create_deserializer(typ: TypeLike, context: 
Optional[ModuleType] = None) -> Deserializer: """ Creates a de-serializer engine to produce a Python object from an object obtained from a JSON string. @@ -900,15 +821,11 @@ def _create_deserializer(typ: TypeLike) -> Deserializer: if typ is list: raise TypeError("explicit item type required: use `List[T]` instead of `list`") if typ is dict: - raise TypeError( - "explicit key and value types required: use `Dict[K, V]` instead of `dict`" - ) + raise TypeError("explicit key and value types required: use `Dict[K, V]` instead of `dict`") if typ is set: raise TypeError("explicit member type required: use `Set[T]` instead of `set`") if typ is tuple: - raise TypeError( - "explicit item type list required: use `Tuple[T, ...]` instead of `tuple`" - ) + raise TypeError("explicit item type list required: use `Tuple[T, ...]` instead of `tuple`") # generic types (e.g. list, dict, set, etc.) origin_type = typing.get_origin(typ) diff --git a/docs/openapi_generator/strong_typing/docstring.py b/llama_stack/strong_typing/docstring.py similarity index 84% rename from docs/openapi_generator/strong_typing/docstring.py rename to llama_stack/strong_typing/docstring.py index 3ef1e5e7a..9169aadfe 100644 --- a/docs/openapi_generator/strong_typing/docstring.py +++ b/llama_stack/strong_typing/docstring.py @@ -164,10 +164,7 @@ def is_exception(member: object) -> TypeGuard[Type[BaseException]]: def get_exceptions(module: types.ModuleType) -> Dict[str, Type[BaseException]]: "Returns all exception classes declared in a module." - return { - name: class_type - for name, class_type in inspect.getmembers(module, is_exception) - } + return {name: class_type for name, class_type in inspect.getmembers(module, is_exception)} class SupportsDoc(Protocol): @@ -212,11 +209,7 @@ def parse_type(typ: SupportsDoc) -> Docstring: for exc_name, exc in docstring.raises.items(): raise_type = context.get(exc_name) if raise_type is None: - type_name = ( - getattr(typ, "__qualname__", None) - or getattr(typ, "__name__", None) - or None - ) + type_name = getattr(typ, "__qualname__", None) or getattr(typ, "__name__", None) or None raise TypeError( f"doc-string exception type `{exc_name}` is not an exception defined in the context of `{type_name}`" ) @@ -262,9 +255,7 @@ def parse_text(text: str) -> Docstring: params: Dict[str, DocstringParam] = {} raises: Dict[str, DocstringRaises] = {} returns = None - for match in re.finditer( - r"(^:.*?)(?=^:|\Z)", meta_chunk, flags=re.DOTALL | re.MULTILINE - ): + for match in re.finditer(r"(^:.*?)(?=^:|\Z)", meta_chunk, flags=re.DOTALL | re.MULTILINE): chunk = match.group(0) if not chunk: continue @@ -307,11 +298,7 @@ def has_default_docstring(typ: SupportsDoc) -> bool: return False if is_dataclass_type(typ): - return ( - typ.__doc__ is not None - and re.match(f"^{re.escape(typ.__name__)}[(].*[)]$", typ.__doc__) - is not None - ) + return typ.__doc__ is not None and re.match(f"^{re.escape(typ.__name__)}[(].*[)]$", typ.__doc__) is not None if is_type_enum(typ): return typ.__doc__ is not None and typ.__doc__ == "An enumeration." @@ -338,9 +325,7 @@ def get_docstring(typ: SupportsDoc) -> Optional[str]: return typ.__doc__ -def check_docstring( - typ: SupportsDoc, docstring: Docstring, strict: bool = False -) -> None: +def check_docstring(typ: SupportsDoc, docstring: Docstring, strict: bool = False) -> None: """ Verifies the doc-string of a type. 
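[Reviewer note] For context on the strong_typing deserializer that moved under llama_stack above, a small sketch of the create_deserializer entry point, assuming the new module path. The Point dataclass is an example type, not from the codebase.

from dataclasses import dataclass
from typing import List

from llama_stack.strong_typing.deserializer import create_deserializer  # new module path

@dataclass
class Point:  # example type, not part of llama-stack
    x: float
    y: float

parser = create_deserializer(List[Point])
points = parser.parse([{"x": 1.0, "y": 2.0}, {"x": 3.5, "y": 4.5}])
print(points)  # expected: [Point(x=1.0, y=2.0), Point(x=3.5, y=4.5)]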
@@ -353,9 +338,7 @@ def check_docstring( check_function_docstring(typ, docstring, strict) -def check_dataclass_docstring( - typ: Type[DataclassInstance], docstring: Docstring, strict: bool = False -) -> None: +def check_dataclass_docstring(typ: Type[DataclassInstance], docstring: Docstring, strict: bool = False) -> None: """ Verifies the doc-string of a data-class type. @@ -371,23 +354,17 @@ def check_dataclass_docstring( for name in docstring.params: if name not in properties: - raise TypeError( - f"doc-string parameter `{name}` is not a member of the data-class `{class_name}`" - ) + raise TypeError(f"doc-string parameter `{name}` is not a member of the data-class `{class_name}`") if not strict: return for name in properties: if name not in docstring.params: - raise TypeError( - f"member `{name}` in data-class `{class_name}` is missing its doc-string" - ) + raise TypeError(f"member `{name}` in data-class `{class_name}` is missing its doc-string") -def check_function_docstring( - fn: Callable[..., Any], docstring: Docstring, strict: bool = False -) -> None: +def check_function_docstring(fn: Callable[..., Any], docstring: Docstring, strict: bool = False) -> None: """ Verifies the doc-string of a function or member function. @@ -400,17 +377,10 @@ def check_function_docstring( for name in docstring.params: if name not in signature.parameters: - raise TypeError( - f"doc-string parameter `{name}` is absent from signature of function `{func_name}`" - ) + raise TypeError(f"doc-string parameter `{name}` is absent from signature of function `{func_name}`") - if ( - docstring.returns is not None - and signature.return_annotation is inspect.Signature.empty - ): - raise TypeError( - f"doc-string has returns description in function `{func_name}` with no return type annotation" - ) + if docstring.returns is not None and signature.return_annotation is inspect.Signature.empty: + raise TypeError(f"doc-string has returns description in function `{func_name}` with no return type annotation") if not strict: return @@ -418,20 +388,12 @@ def check_function_docstring( for name, param in signature.parameters.items(): # ignore `self` in member function signatures if name == "self" and ( - param.kind is inspect.Parameter.POSITIONAL_ONLY - or param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD + param.kind is inspect.Parameter.POSITIONAL_ONLY or param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD ): continue if name not in docstring.params: - raise TypeError( - f"function parameter `{name}` in `{func_name}` is missing its doc-string" - ) + raise TypeError(f"function parameter `{name}` in `{func_name}` is missing its doc-string") - if ( - signature.return_annotation is not inspect.Signature.empty - and docstring.returns is None - ): - raise TypeError( - f"function `{func_name}` has no returns description in its doc-string" - ) + if signature.return_annotation is not inspect.Signature.empty and docstring.returns is None: + raise TypeError(f"function `{func_name}` has no returns description in its doc-string") diff --git a/docs/openapi_generator/strong_typing/exception.py b/llama_stack/strong_typing/exception.py similarity index 100% rename from docs/openapi_generator/strong_typing/exception.py rename to llama_stack/strong_typing/exception.py diff --git a/docs/openapi_generator/strong_typing/inspection.py b/llama_stack/strong_typing/inspection.py similarity index 96% rename from docs/openapi_generator/strong_typing/inspection.py rename to llama_stack/strong_typing/inspection.py index 41804f12c..8bc313021 100644 --- 
a/docs/openapi_generator/strong_typing/inspection.py +++ b/llama_stack/strong_typing/inspection.py @@ -32,12 +32,12 @@ from typing import ( NamedTuple, Optional, Protocol, - runtime_checkable, Set, Tuple, Type, TypeVar, Union, + runtime_checkable, ) if sys.version_info >= (3, 9): @@ -161,9 +161,7 @@ class DataclassField: type: Any default: Any - def __init__( - self, name: str, type: Any, default: Any = dataclasses.MISSING - ) -> None: + def __init__(self, name: str, type: Any, default: Any = dataclasses.MISSING) -> None: self.name = name self.type = type self.default = default @@ -173,9 +171,7 @@ def dataclass_fields(cls: Type[DataclassInstance]) -> Iterable[DataclassField]: "Generates the fields of a data-class resolving forward references." for field in dataclasses.fields(cls): - yield DataclassField( - field.name, evaluate_member_type(field.type, cls), field.default - ) + yield DataclassField(field.name, evaluate_member_type(field.type, cls), field.default) def dataclass_field_by_name(cls: Type[DataclassInstance], name: str) -> DataclassField: @@ -267,8 +263,8 @@ def extend_enum( enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore # assign the newly created type to the same module where the extending class is defined - setattr(enum_class, "__module__", extend.__module__) - setattr(enum_class, "__doc__", extend.__doc__) + enum_class.__module__ = extend.__module__ + enum_class.__doc__ = extend.__doc__ setattr(sys.modules[extend.__module__], extend.__name__, enum_class) return enum.unique(enum_class) @@ -291,9 +287,7 @@ else: return typing.get_origin(typ) is Union -def is_type_optional( - typ: object, strict: bool = False -) -> TypeGuard[Type[Optional[Any]]]: +def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[Type[Optional[Any]]]: """ True if the type annotation corresponds to an optional type (e.g. `Optional[T]` or `Union[T1,T2,None]`). @@ -525,9 +519,7 @@ def unwrap_annotated_type(typ: T) -> T: return typ -def rewrap_annotated_type( - transform: Callable[[Type[S]], Type[T]], typ: Type[S] -) -> Type[T]: +def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S]) -> Type[T]: """ Un-boxes, transforms and re-boxes an optionally annotated type. @@ -595,9 +587,7 @@ class _ROOT: pass -def get_referenced_types( - typ: TypeLike, module: Optional[types.ModuleType] = None -) -> Set[type]: +def get_referenced_types(typ: TypeLike, module: Optional[types.ModuleType] = None) -> Set[type]: """ Extracts types directly or indirectly referenced by this type. 
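[Reviewer note] A short sketch of the inspection helpers touched above, using their new import location; the Optional[int] example is arbitrary.

from typing import Optional

from llama_stack.strong_typing.inspection import (  # new module path
    is_type_optional,
    unwrap_optional_type,
)

typ = Optional[int]
if is_type_optional(typ):
    # unwraps Union[T, None] down to T
    print(unwrap_optional_type(typ))  # expected: <class 'int'>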
@@ -867,8 +857,7 @@ def is_generic_instance(obj: Any, typ: TypeLike) -> bool: key_type, value_type = typing.get_args(typ) dict_obj: dict = obj return all( - is_generic_instance(key, key_type) - and is_generic_instance(value, value_type) + is_generic_instance(key, key_type) and is_generic_instance(value, value_type) for key, value in dict_obj.items() ) elif origin_type is set: @@ -885,13 +874,11 @@ def is_generic_instance(obj: Any, typ: TypeLike) -> bool: for tuple_item_type, item in zip( (tuple_item_type for tuple_item_type in typing.get_args(typ)), (item for item in obj), + strict=False, ) ) elif origin_type is Union: - return any( - is_generic_instance(obj, member_type) - for member_type in typing.get_args(typ) - ) + return any(is_generic_instance(obj, member_type) for member_type in typing.get_args(typ)) elif isinstance(typ, type): return isinstance(obj, typ) else: @@ -968,6 +955,7 @@ class RecursiveChecker: for tuple_item_type, item in zip( (tuple_item_type for tuple_item_type in typing.get_args(typ)), (item for item in obj), + strict=False, ) ) elif origin_type is Union: @@ -995,8 +983,7 @@ class RecursiveChecker: raise TypeError(f"expected `{typ}` but got: {obj}") resolved_hints = get_resolved_hints(typ) return all( - self.check(resolved_hints[field.name], getattr(obj, field.name)) - for field in dataclasses.fields(typ) + self.check(resolved_hints[field.name], getattr(obj, field.name)) for field in dataclasses.fields(typ) ) else: if not isinstance(obj, typ): @@ -1027,9 +1014,7 @@ def check_recursive( if type_pred is not None and value_pred is not None: if pred is not None: - raise TypeError( - "filter predicate not permitted when type and value predicates are present" - ) + raise TypeError("filter predicate not permitted when type and value predicates are present") type_p: Callable[[Type[T]], bool] = type_pred value_p: Callable[[T], bool] = value_pred @@ -1037,9 +1022,7 @@ def check_recursive( elif value_pred is not None: if pred is not None: - raise TypeError( - "filter predicate not permitted when value predicate is present" - ) + raise TypeError("filter predicate not permitted when value predicate is present") value_only_p: Callable[[T], bool] = value_pred pred = lambda typ, obj: value_only_p(obj) # noqa: E731 diff --git a/docs/openapi_generator/strong_typing/mapping.py b/llama_stack/strong_typing/mapping.py similarity index 91% rename from docs/openapi_generator/strong_typing/mapping.py rename to llama_stack/strong_typing/mapping.py index 2bc68bb63..408375a9f 100644 --- a/docs/openapi_generator/strong_typing/mapping.py +++ b/llama_stack/strong_typing/mapping.py @@ -17,9 +17,7 @@ from .auxiliary import Alias from .inspection import get_annotation -def python_field_to_json_property( - python_id: str, python_type: Optional[object] = None -) -> str: +def python_field_to_json_property(python_id: str, python_type: Optional[object] = None) -> str: """ Map a Python field identifier to a JSON property name. 
diff --git a/docs/openapi_generator/strong_typing/name.py b/llama_stack/strong_typing/name.py similarity index 94% rename from docs/openapi_generator/strong_typing/name.py rename to llama_stack/strong_typing/name.py index c883794c0..a1a2ae5f1 100644 --- a/docs/openapi_generator/strong_typing/name.py +++ b/llama_stack/strong_typing/name.py @@ -15,11 +15,11 @@ from typing import Any, Literal, Optional, Tuple, Union from .auxiliary import _auxiliary_types from .inspection import ( + TypeLike, is_generic_dict, is_generic_list, is_type_optional, is_type_union, - TypeLike, unwrap_generic_dict, unwrap_generic_list, unwrap_optional_type, @@ -110,17 +110,13 @@ class TypeFormatter: if arg is not auxiliary_arg: continue - auxiliary_metatuple: Optional[Tuple[Any, ...]] = getattr( - auxiliary_type, "__metadata__", None - ) + auxiliary_metatuple: Optional[Tuple[Any, ...]] = getattr(auxiliary_type, "__metadata__", None) if auxiliary_metatuple is None: continue if metaset.issuperset(auxiliary_metatuple): # type is an auxiliary type with extra annotations - auxiliary_args = ", ".join( - repr(m) for m in metatuple if m not in auxiliary_metatuple - ) + auxiliary_args = ", ".join(repr(m) for m in metatuple if m not in auxiliary_metatuple) return f"Annotated[{auxiliary_name}, {auxiliary_args}]" # type is an annotated type @@ -176,9 +172,7 @@ def python_type_to_name(data_type: TypeLike, force: bool = False) -> str: return f"Dict__{key_name}__{value_name}" elif is_type_union(data_type): member_types = unwrap_union_types(data_type) - member_names = "__".join( - python_type_to_name(member_type) for member_type in member_types - ) + member_names = "__".join(python_type_to_name(member_type) for member_type in member_types) return f"Union__{member_names}" # named system or user-defined type diff --git a/docs/openapi_generator/strong_typing/py.typed b/llama_stack/strong_typing/py.typed similarity index 100% rename from docs/openapi_generator/strong_typing/py.typed rename to llama_stack/strong_typing/py.typed diff --git a/docs/openapi_generator/strong_typing/schema.py b/llama_stack/strong_typing/schema.py similarity index 91% rename from docs/openapi_generator/strong_typing/schema.py rename to llama_stack/strong_typing/schema.py index 7f44435b8..dfc51ea78 100644 --- a/docs/openapi_generator/strong_typing/schema.py +++ b/llama_stack/strong_typing/schema.py @@ -28,11 +28,11 @@ from typing import ( List, Literal, Optional, - overload, Tuple, Type, TypeVar, Union, + overload, ) import jsonschema @@ -41,21 +41,21 @@ from typing_extensions import Annotated from . 
import docstring from .auxiliary import ( Alias, - get_auxiliary_format, IntegerRange, MaxLength, MinLength, Precision, + get_auxiliary_format, ) from .core import JsonArray, JsonObject, JsonType, Schema, StrictJsonType from .inspection import ( + TypeLike, enum_value_types, get_annotation, get_class_properties, is_type_enum, is_type_like, is_type_optional, - TypeLike, unwrap_optional_type, ) from .name import python_type_to_name @@ -108,7 +108,9 @@ def get_class_property_docstrings( def docstring_to_schema(data_type: type) -> Schema: short_description, long_description = get_class_docstrings(data_type) - schema: Schema = {} + schema: Schema = { + "title": python_type_to_name(data_type), + } description = "\n".join(filter(None, [short_description, long_description])) if description: @@ -240,17 +242,13 @@ class JsonSchemaGenerator: def _(self, arg: MaxLength) -> Schema: return {"maxLength": arg.value} - def _with_metadata( - self, type_schema: Schema, metadata: Optional[Tuple[Any, ...]] - ) -> Schema: + def _with_metadata(self, type_schema: Schema, metadata: Optional[Tuple[Any, ...]]) -> Schema: if metadata: for m in metadata: type_schema.update(self._metadata_to_schema(m)) return type_schema - def _simple_type_to_schema( - self, typ: TypeLike, json_schema_extra: Optional[dict] = None - ) -> Optional[Schema]: + def _simple_type_to_schema(self, typ: TypeLike, json_schema_extra: Optional[dict] = None) -> Optional[Schema]: """ Returns the JSON schema associated with a simple, unrestricted type. @@ -315,6 +313,17 @@ class JsonSchemaGenerator: data_type: TypeLike, force_expand: bool = False, json_schema_extra: Optional[dict] = None, + ) -> Schema: + common_info = {} + if json_schema_extra and "deprecated" in json_schema_extra: + common_info["deprecated"] = json_schema_extra["deprecated"] + return self._type_to_schema(data_type, force_expand, json_schema_extra) | common_info + + def _type_to_schema( + self, + data_type: TypeLike, + force_expand: bool = False, + json_schema_extra: Optional[dict] = None, ) -> Schema: """ Returns the JSON schema associated with a type. 
@@ -379,12 +388,7 @@ class JsonSchemaGenerator: enum_value_type = value_types.pop() enum_schema: Schema - if ( - enum_value_type is bool - or enum_value_type is int - or enum_value_type is float - or enum_value_type is str - ): + if enum_value_type is bool or enum_value_type is int or enum_value_type is float or enum_value_type is str: if enum_value_type is bool: enum_schema_type = "boolean" elif enum_value_type is int: @@ -414,9 +418,7 @@ class JsonSchemaGenerator: elif origin_type is dict: key_type, value_type = typing.get_args(typ) if not (key_type is str or key_type is int or is_type_enum(key_type)): - raise ValueError( - "`dict` with key type not coercible to `str` is not supported" - ) + raise ValueError("`dict` with key type not coercible to `str` is not supported") dict_schema: Schema value_schema = self.type_to_schema(value_type) @@ -424,9 +426,7 @@ class JsonSchemaGenerator: enum_values = [str(e.value) for e in key_type] if len(enum_values) > OBJECT_ENUM_EXPANSION_LIMIT: dict_schema = { - "propertyNames": { - "pattern": "^(" + "|".join(enum_values) + ")$" - }, + "propertyNames": {"pattern": "^(" + "|".join(enum_values) + ")$"}, "additionalProperties": value_schema, } else: @@ -453,30 +453,19 @@ class JsonSchemaGenerator: "type": "array", "minItems": len(args), "maxItems": len(args), - "prefixItems": [ - self.type_to_schema(member_type) for member_type in args - ], + "prefixItems": [self.type_to_schema(member_type) for member_type in args], } elif origin_type is Union: discriminator = None if typing.get_origin(data_type) is Annotated: discriminator = typing.get_args(data_type)[1].discriminator - ret = { - "oneOf": [ - self.type_to_schema(union_type) - for union_type in typing.get_args(typ) - ] - } + ret = {"oneOf": [self.type_to_schema(union_type) for union_type in typing.get_args(typ)]} if discriminator: # for each union type, we need to read the value of the discriminator mapping = {} for union_type in typing.get_args(typ): - props = self.type_to_schema(union_type, force_expand=True)[ - "properties" - ] - mapping[props[discriminator]["default"]] = self.type_to_schema( - union_type - )["$ref"] + props = self.type_to_schema(union_type, force_expand=True)["properties"] + mapping[props[discriminator]["default"]] = self.type_to_schema(union_type)["$ref"] ret["discriminator"] = { "propertyName": discriminator, @@ -495,9 +484,7 @@ class JsonSchemaGenerator: # dictionary of class attributes members = dict(inspect.getmembers(typ, lambda a: not inspect.isroutine(a))) - property_docstrings = get_class_property_docstrings( - typ, self.options.property_description_fun - ) + property_docstrings = get_class_property_docstrings(typ, self.options.property_description_fun) properties: Dict[str, Schema] = {} required: List[str] = [] for property_name, property_type in get_class_properties(typ): @@ -513,17 +500,17 @@ class JsonSchemaGenerator: if "model_fields" in members: f = members["model_fields"] defaults = {k: finfo.default for k, finfo in f.items()} - json_schema_extra = f.get(output_name, None).json_schema_extra + if output_name in f: + finfo = f[output_name] + json_schema_extra = finfo.json_schema_extra or {} + if finfo.deprecated: + json_schema_extra["deprecated"] = True if is_type_optional(property_type): optional_type: type = unwrap_optional_type(property_type) - property_def = self.type_to_schema( - optional_type, json_schema_extra=json_schema_extra - ) + property_def = self.type_to_schema(optional_type, json_schema_extra=json_schema_extra) else: - property_def = self.type_to_schema( - 
property_type, json_schema_extra=json_schema_extra - ) + property_def = self.type_to_schema(property_type, json_schema_extra=json_schema_extra) required.append(output_name) # check if attribute has a default value initializer @@ -580,9 +567,7 @@ class JsonSchemaGenerator: # add descriptive text (if present) if self.options.use_descriptions: - if isinstance(data_type, type) and not isinstance( - data_type, typing.ForwardRef - ): + if isinstance(data_type, type) and not isinstance(data_type, typing.ForwardRef): type_schema.update(docstring_to_schema(data_type)) # add example (if present) @@ -591,9 +576,7 @@ class JsonSchemaGenerator: return type_schema - def classdef_to_schema( - self, data_type: TypeLike, force_expand: bool = False - ) -> Tuple[Schema, Dict[str, Schema]]: + def classdef_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> Tuple[Schema, Dict[str, Schema]]: """ Returns the JSON schema associated with a type and any nested types. @@ -668,9 +651,7 @@ def classdef_to_schema( try: validator.value.check_schema(class_schema) except jsonschema.exceptions.SchemaError: - raise TypeError( - f"schema does not validate against meta-schema <{validator_id}>" - ) + raise TypeError(f"schema does not validate against meta-schema <{validator_id}>") schema = {"$schema": validator_id} schema.update(class_schema) @@ -687,9 +668,7 @@ def validate_object(data_type: TypeLike, json_dict: JsonType) -> None: """ schema_dict = classdef_to_schema(data_type) - jsonschema.validate( - json_dict, schema_dict, format_checker=jsonschema.FormatChecker() - ) + jsonschema.validate(json_dict, schema_dict, format_checker=jsonschema.FormatChecker()) def print_schema(data_type: type) -> None: @@ -735,9 +714,7 @@ def json_schema_type(cls: Type[T], /) -> Type[T]: ... @overload -def json_schema_type( - cls: None, *, schema: Optional[Schema] = None -) -> Callable[[Type[T]], Type[T]]: ... +def json_schema_type(cls: None, *, schema: Optional[Schema] = None) -> Callable[[Type[T]], Type[T]]: ... def json_schema_type( diff --git a/docs/openapi_generator/strong_typing/serialization.py b/llama_stack/strong_typing/serialization.py similarity index 93% rename from docs/openapi_generator/strong_typing/serialization.py rename to llama_stack/strong_typing/serialization.py index 88d8fccad..c00a0aad5 100644 --- a/docs/openapi_generator/strong_typing/serialization.py +++ b/llama_stack/strong_typing/serialization.py @@ -42,9 +42,7 @@ def object_to_json(obj: Any) -> JsonType: return generator.generate(obj) -def json_to_object( - typ: TypeLike, data: JsonType, *, context: Optional[ModuleType] = None -) -> object: +def json_to_object(typ: TypeLike, data: JsonType, *, context: Optional[ModuleType] = None) -> object: """ Creates an object from a representation that has been de-serialized from JSON. @@ -85,9 +83,7 @@ def json_to_object( def json_dump_string(json_object: JsonType) -> str: "Dump an object as a JSON string with a compact representation." 
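[Reviewer note] To see the effect of the schema.py changes above (the added "title" and the propagation of a field's deprecated flag through json_schema_extra), here is a hedged sketch. It assumes a Pydantic release where Field(deprecated=...) is available; the exact emitted layout may differ.

from pydantic import BaseModel, Field

from llama_stack.strong_typing.schema import print_schema  # new module path

class ExampleRequest(BaseModel):  # illustrative model, not from the codebase
    """An example request body."""
    query: str
    legacy_flag: bool = Field(default=False, deprecated=True)  # assumes Field(deprecated=...) support

# With the change above, the emitted schema is expected to carry a "title"
# ("ExampleRequest") and to mark the legacy_flag property with "deprecated": true.
print_schema(ExampleRequest)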
- return json.dumps( - json_object, ensure_ascii=False, check_circular=False, separators=(",", ":") - ) + return json.dumps(json_object, ensure_ascii=False, check_circular=False, separators=(",", ":")) def json_dump(json_object: JsonType, file: TextIO) -> None: diff --git a/docs/openapi_generator/strong_typing/serializer.py b/llama_stack/strong_typing/serializer.py similarity index 92% rename from docs/openapi_generator/strong_typing/serializer.py rename to llama_stack/strong_typing/serializer.py index f1252e374..4ca4a4119 100644 --- a/docs/openapi_generator/strong_typing/serializer.py +++ b/llama_stack/strong_typing/serializer.py @@ -40,6 +40,7 @@ from typing import ( from .core import JsonType from .exception import JsonTypeError, JsonValueError from .inspection import ( + TypeLike, enum_value_types, evaluate_type, get_class_properties, @@ -49,7 +50,6 @@ from .inspection import ( is_reserved_property, is_type_annotated, is_type_enum, - TypeLike, unwrap_annotated_type, ) from .mapping import python_field_to_json_property @@ -100,9 +100,7 @@ class BytesSerializer(Serializer[bytes]): class DateTimeSerializer(Serializer[datetime.datetime]): def generate(self, obj: datetime.datetime) -> str: if obj.tzinfo is None: - raise JsonValueError( - f"timestamp lacks explicit time zone designator: {obj}" - ) + raise JsonValueError(f"timestamp lacks explicit time zone designator: {obj}") fmt = obj.isoformat() if fmt.endswith("+00:00"): fmt = f"{fmt[:-6]}Z" # Python's isoformat() does not support military time zones like "Zulu" for UTC @@ -147,9 +145,7 @@ class UntypedListSerializer(Serializer[list]): class UntypedDictSerializer(Serializer[dict]): def generate(self, obj: dict) -> Dict[str, JsonType]: if obj and isinstance(next(iter(obj.keys())), enum.Enum): - iterator = ( - (key.value, object_to_json(value)) for key, value in obj.items() - ) + iterator = ((key.value, object_to_json(value)) for key, value in obj.items()) else: iterator = ((str(key), object_to_json(value)) for key, value in obj.items()) return dict(iterator) @@ -202,9 +198,7 @@ class TypedEnumDictSerializer(TypedCollectionSerializer[T]): value_type = value_types.pop() if value_type is not str: - raise JsonTypeError( - "invalid enumeration key type, expected `enum.Enum` with string values" - ) + raise JsonTypeError("invalid enumeration key type, expected `enum.Enum` with string values") def generate(self, obj: Dict[enum.Enum, T]) -> Dict[str, JsonType]: return {key.value: self.generator.generate(value) for key, value in obj.items()} @@ -218,18 +212,11 @@ class TypedSetSerializer(TypedCollectionSerializer[T]): class TypedTupleSerializer(Serializer[tuple]): item_generators: Tuple[Serializer, ...] 
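[Reviewer note] And a brief round-trip sketch for the serialization helpers touched above, assuming the new module path. The Span dataclass is again just an example type.

from dataclasses import dataclass

from llama_stack.strong_typing.serialization import (  # new module path
    json_dump_string,
    json_to_object,
    object_to_json,
)

@dataclass
class Span:  # example type, not part of llama-stack
    name: str
    duration_ms: int

payload = object_to_json(Span(name="inference", duration_ms=42))
print(json_dump_string(payload))   # compact separators, e.g. {"name":"inference","duration_ms":42}
restored = json_to_object(Span, payload)
print(restored)                    # expected: Span(name='inference', duration_ms=42)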
- def __init__( - self, item_types: Tuple[type, ...], context: Optional[ModuleType] - ) -> None: - self.item_generators = tuple( - _get_serializer(item_type, context) for item_type in item_types - ) + def __init__(self, item_types: Tuple[type, ...], context: Optional[ModuleType]) -> None: + self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types) def generate(self, obj: tuple) -> List[JsonType]: - return [ - item_generator.generate(item) - for item_generator, item in zip(self.item_generators, obj) - ] + return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj, strict=False)] class CustomSerializer(Serializer): @@ -255,9 +242,7 @@ class FieldSerializer(Generic[T]): property_name: str generator: Serializer - def __init__( - self, field_name: str, property_name: str, generator: Serializer[T] - ) -> None: + def __init__(self, field_name: str, property_name: str, generator: Serializer[T]) -> None: self.field_name = field_name self.property_name = property_name self.generator = generator @@ -290,9 +275,7 @@ class TypedClassSerializer(Serializer[T]): class TypedNamedTupleSerializer(TypedClassSerializer[NamedTuple]): - def __init__( - self, class_type: Type[NamedTuple], context: Optional[ModuleType] - ) -> None: + def __init__(self, class_type: Type[NamedTuple], context: Optional[ModuleType]) -> None: super().__init__(class_type, context) @@ -365,9 +348,7 @@ class UntypedClassSerializer(Serializer): return object_dict -def create_serializer( - typ: TypeLike, context: Optional[ModuleType] = None -) -> Serializer: +def create_serializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Serializer: """ Creates a serializer engine to produce an object that can be directly converted into a JSON string. @@ -489,13 +470,7 @@ def _create_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializ return UntypedNamedTupleSerializer(typ) # fail early if caller passes an object with an exotic type - if ( - not isinstance(typ, type) - or typ is FunctionType - or typ is MethodType - or typ is type - or typ is ModuleType - ): + if not isinstance(typ, type) or typ is FunctionType or typ is MethodType or typ is type or typ is ModuleType: raise TypeError(f"object of type {typ} cannot be represented in JSON") if get_resolved_hints(typ): diff --git a/docs/openapi_generator/strong_typing/slots.py b/llama_stack/strong_typing/slots.py similarity index 88% rename from docs/openapi_generator/strong_typing/slots.py rename to llama_stack/strong_typing/slots.py index 564ffa11f..c1a3293d8 100644 --- a/docs/openapi_generator/strong_typing/slots.py +++ b/llama_stack/strong_typing/slots.py @@ -10,9 +10,7 @@ T = TypeVar("T") class SlotsMeta(type): - def __new__( - cls: Type[T], name: str, bases: Tuple[type, ...], ns: Dict[str, Any] - ) -> T: + def __new__(cls: Type[T], name: str, bases: Tuple[type, ...], ns: Dict[str, Any]) -> T: # caller may have already provided slots, in which case just retain them and keep going slots: Tuple[str, ...] 
= ns.get("__slots__", ()) diff --git a/docs/openapi_generator/strong_typing/topological.py b/llama_stack/strong_typing/topological.py similarity index 100% rename from docs/openapi_generator/strong_typing/topological.py rename to llama_stack/strong_typing/topological.py diff --git a/llama_stack/templates/bedrock/bedrock.py b/llama_stack/templates/bedrock/bedrock.py index af1d48b7f..0b294824d 100644 --- a/llama_stack/templates/bedrock/bedrock.py +++ b/llama_stack/templates/bedrock/bedrock.py @@ -6,10 +6,9 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.apis.models import ModelInput from llama_stack.distribution.datatypes import Provider, ToolGroupInput +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings diff --git a/llama_stack/templates/bedrock/doc_template.md b/llama_stack/templates/bedrock/doc_template.md index 2121719b7..357638ea5 100644 --- a/llama_stack/templates/bedrock/doc_template.md +++ b/llama_stack/templates/bedrock/doc_template.md @@ -55,7 +55,8 @@ docker run \ --port $LLAMA_STACK_PORT \ --env AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ --env AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ - --env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN + --env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN \ + --env AWS_DEFAULT_REGION=$AWS_DEFAULT_REGION ``` ### Via Conda @@ -66,5 +67,6 @@ llama stack run ./run.yaml \ --port $LLAMA_STACK_PORT \ --env AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ --env AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ - --env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN + --env AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN \ + --env AWS_DEFAULT_REGION=$AWS_DEFAULT_REGION ``` diff --git a/llama_stack/templates/bedrock/run.yaml b/llama_stack/templates/bedrock/run.yaml index be6c9a928..7d03b7c29 100644 --- a/llama_stack/templates/bedrock/run.yaml +++ b/llama_stack/templates/bedrock/run.yaml @@ -107,7 +107,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/cerebras/cerebras.py b/llama_stack/templates/cerebras/cerebras.py index 870240feb..4f6d0c8f3 100644 --- a/llama_stack/templates/cerebras/cerebras.py +++ b/llama_stack/templates/cerebras/cerebras.py @@ -6,10 +6,9 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml index 05d3f4525..6afff2be2 100644 --- a/llama_stack/templates/cerebras/run.yaml +++ b/llama_stack/templates/cerebras/run.yaml @@ -109,7 +109,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/dell/run-with-safety.yaml b/llama_stack/templates/dell/run-with-safety.yaml index 04c5957d4..ddec3a715 100644 --- 
a/llama_stack/templates/dell/run-with-safety.yaml +++ b/llama_stack/templates/dell/run-with-safety.yaml @@ -108,7 +108,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: brave-search diff --git a/llama_stack/templates/dell/run.yaml b/llama_stack/templates/dell/run.yaml index 706444eb1..9394c94ef 100644 --- a/llama_stack/templates/dell/run.yaml +++ b/llama_stack/templates/dell/run.yaml @@ -99,7 +99,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: brave-search diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml index 75d103c9f..e70ccdd2d 100644 --- a/llama_stack/templates/experimental-post-training/run.yaml +++ b/llama_stack/templates/experimental-post-training/run.yaml @@ -85,4 +85,4 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] diff --git a/llama_stack/templates/fireworks/fireworks.py b/llama_stack/templates/fireworks/fireworks.py index e2e2ca99c..a6809fef6 100644 --- a/llama_stack/templates/fireworks/fireworks.py +++ b/llama_stack/templates/fireworks/fireworks.py @@ -6,8 +6,6 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ( ModelInput, @@ -15,6 +13,7 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) diff --git a/llama_stack/templates/fireworks/run-with-safety.yaml b/llama_stack/templates/fireworks/run-with-safety.yaml index 0fbe14a5a..8f95e9d59 100644 --- a/llama_stack/templates/fireworks/run-with-safety.yaml +++ b/llama_stack/templates/fireworks/run-with-safety.yaml @@ -164,7 +164,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/fireworks/run.yaml b/llama_stack/templates/fireworks/run.yaml index ccf67dcbb..64229a5d8 100644 --- a/llama_stack/templates/fireworks/run.yaml +++ b/llama_stack/templates/fireworks/run.yaml @@ -153,7 +153,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/hf-endpoint/run-with-safety.yaml b/llama_stack/templates/hf-endpoint/run-with-safety.yaml index f520a2fda..867d7a076 100644 --- a/llama_stack/templates/hf-endpoint/run-with-safety.yaml +++ b/llama_stack/templates/hf-endpoint/run-with-safety.yaml @@ -116,7 +116,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/hf-endpoint/run.yaml b/llama_stack/templates/hf-endpoint/run.yaml index 708cb1bcc..d60acdefd 100644 --- a/llama_stack/templates/hf-endpoint/run.yaml +++ b/llama_stack/templates/hf-endpoint/run.yaml @@ -106,7 +106,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git 
a/llama_stack/templates/hf-serverless/run-with-safety.yaml b/llama_stack/templates/hf-serverless/run-with-safety.yaml index 7f0abf5be..e58ad15b3 100644 --- a/llama_stack/templates/hf-serverless/run-with-safety.yaml +++ b/llama_stack/templates/hf-serverless/run-with-safety.yaml @@ -116,7 +116,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/hf-serverless/run.yaml b/llama_stack/templates/hf-serverless/run.yaml index c0b7a4c60..5045e821a 100644 --- a/llama_stack/templates/hf-serverless/run.yaml +++ b/llama_stack/templates/hf-serverless/run.yaml @@ -106,7 +106,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index c5286fc6b..caac65c8c 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -118,7 +118,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index 310585f23..bade9a076 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -107,7 +107,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/meta-reference-quantized-gpu/run.yaml b/llama_stack/templates/meta-reference-quantized-gpu/run.yaml index d43cf3917..f131e8ea6 100644 --- a/llama_stack/templates/meta-reference-quantized-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-quantized-gpu/run.yaml @@ -109,7 +109,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index d24c9ed48..ee22b5555 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -6,9 +6,8 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index c8ae362f5..14fb28354 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -139,7 +139,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index ac5dab755..9d5bfc7a0 100644 --- 
a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -113,7 +113,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 3a60fe61f..9ac1f3267 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -110,7 +110,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml index 1fe998a1f..dd43f21f6 100644 --- a/llama_stack/templates/remote-vllm/run-with-safety.yaml +++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml @@ -118,7 +118,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml index 9d3db8a31..24cd207c7 100644 --- a/llama_stack/templates/remote-vllm/run.yaml +++ b/llama_stack/templates/remote-vllm/run.yaml @@ -107,7 +107,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 39b0f3c4e..26815dcd0 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -118,7 +118,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index 6d7477c8e..c7a9428af 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -6,14 +6,13 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.distribution.datatypes import ( ModelInput, Provider, ShieldInput, ToolGroupInput, ) +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig from llama_stack.providers.remote.inference.sambanova.sambanova import MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings diff --git a/llama_stack/templates/tgi/run-with-safety.yaml b/llama_stack/templates/tgi/run-with-safety.yaml index ed6c9ef6f..e1d85f59a 100644 --- a/llama_stack/templates/tgi/run-with-safety.yaml +++ b/llama_stack/templates/tgi/run-with-safety.yaml @@ -106,7 +106,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/tgi/run.yaml b/llama_stack/templates/tgi/run.yaml index 8bf76f37b..fc73e0978 100644 --- a/llama_stack/templates/tgi/run.yaml +++ b/llama_stack/templates/tgi/run.yaml @@ -105,7 +105,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git 
a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml index 298926630..f101a5d60 100644 --- a/llama_stack/templates/together/run-with-safety.yaml +++ b/llama_stack/templates/together/run-with-safety.yaml @@ -159,7 +159,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index 920003759..8af85979d 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -148,7 +148,7 @@ shields: vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/llama_stack/templates/together/together.py b/llama_stack/templates/together/together.py index 9ec5b38ba..f7b18e32a 100644 --- a/llama_stack/templates/together/together.py +++ b/llama_stack/templates/together/together.py @@ -6,8 +6,6 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ( ModelInput, @@ -15,6 +13,7 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) diff --git a/llama_stack/templates/vllm-gpu/run.yaml b/llama_stack/templates/vllm-gpu/run.yaml index 41a545e1a..cdce5510d 100644 --- a/llama_stack/templates/vllm-gpu/run.yaml +++ b/llama_stack/templates/vllm-gpu/run.yaml @@ -109,7 +109,7 @@ shields: [] vector_dbs: [] datasets: [] scoring_fns: [] -eval_tasks: [] +benchmarks: [] tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/pyproject.toml b/pyproject.toml index 2f40ceac9..71af2cc99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "llama_stack" -version = "0.1.2" +version = "0.1.3" authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }] description = "Llama Stack" readme = "README.md" @@ -25,8 +25,9 @@ dependencies = [ "fire", "httpx", "huggingface-hub", - "llama-models>=0.1.2", - "llama-stack-client>=0.1.2", + "jsonschema", + "llama-models>=0.1.3", + "llama-stack-client>=0.1.3", "prompt-toolkit", "python-dotenv", "pydantic>=2", @@ -76,3 +77,66 @@ license-files = [] name = "pytorch-cpu" url = "https://download.pytorch.org/whl/cpu" explicit = true + +[tool.ruff] +line-length = 120 +exclude = [ + "./.git", + "./docs/*", + "./build", + "./scripts", + "./venv", + "*.pyi", + ".pre-commit-config.yaml", + "*.md", + ".flake8", +] + +[tool.ruff.lint] +select = [ + "B", # flake8-bugbear + "B9", # flake8-bugbear subset + "C", # comprehensions + "E", # pycodestyle + "F", # Pyflakes + "N", # Naming + "W", # Warnings + "I", # isort +] +ignore = [ + "E203", + "E305", + "E402", + "E501", # line too long + "E721", + "E741", + "F405", + "F821", + "F841", + "C408", # ignored because we like the dict keyword argument syntax + "E302", + "W291", + "E303", + "N812", # ignored because import torch.nn.functional as F is PyTorch convention + "N817", # ignored because importing using acronyms is convention (DistributedDataParallel as DDP) + "E731", # allow usage of assigning lambda expressions + # These are 
the additional ones we started ignoring after moving to ruff. We should look into each one of them later. + "C901", + "C405", + "C414", + "N803", + "N999", + "C403", + "C416", + "B028", + "C419", + "C401", + "B023", + # shebang has extra meaning in fbcode lints, so I think it's not worth trying + # to line this up with executable bit + "EXE001", + "N802", # random naming hints don't need + # these ignores are from flake8-bugbear; please fix! + "B007", + "B008", +] diff --git a/requirements.txt b/requirements.txt index 497feb764..b72c240bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ # This file was autogenerated by uv via the following command: -# uv export --frozen --no-hashes --no-emit-project +# uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt annotated-types==0.7.0 anyio==4.8.0 +attrs==25.1.0 blobfile==3.0.0 certifi==2025.1.31 -chardet==5.2.0 charset-normalizer==3.4.1 click==8.1.8 colorama==0.4.6 ; sys_platform == 'win32' @@ -19,8 +19,10 @@ httpx==0.28.1 huggingface-hub==0.28.1 idna==3.10 jinja2==3.1.5 -llama-models==0.1.2 -llama-stack-client==0.1.2 +jsonschema==4.23.0 +jsonschema-specifications==2024.10.1 +llama-models==0.1.3 +llama-stack-client==0.1.3 lxml==5.3.0 markdown-it-py==3.0.0 markupsafe==3.0.2 @@ -35,14 +37,15 @@ pycryptodomex==3.21.0 pydantic==2.10.6 pydantic-core==2.27.2 pygments==2.19.1 -pypdf==5.2.0 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 pytz==2025.1 pyyaml==6.0.2 +referencing==0.36.2 regex==2024.11.6 requests==2.32.3 rich==13.9.4 +rpds-py==0.22.3 setuptools==75.8.0 six==1.17.0 sniffio==1.3.1 diff --git a/rfcs/RFC-0001-llama-stack.md b/rfcs/RFC-0001-llama-stack.md index 2ff7838c1..7ba125c36 100644 --- a/rfcs/RFC-0001-llama-stack.md +++ b/rfcs/RFC-0001-llama-stack.md @@ -1,12 +1,15 @@ # The Llama Stack API **Authors:** + * Meta: @raghotham, @ashwinb, @hjshah, @jspisak ## Summary + As part of the Llama 3.1 release, Meta is releasing an RFC for ‘Llama Stack’, a comprehensive set of interfaces / API for ML developers building on top of Llama foundation models. We are looking for feedback on where the API can be improved, any corner cases we may have missed and your general thoughts on how useful this will be. Ultimately, our hope is to create a standard for working with Llama models in order to simplify the developer experience and foster innovation across the Llama ecosystem. ## Motivation + Llama models were always intended to work as part of an overall system that can orchestrate several components, including calling external tools. Our vision is to go beyond the foundation models and give developers access to a broader system that gives them the flexibility to design and create custom offerings that align with their vision. This thinking started last year when we first introduced a system-level safety model. Meta has continued to release new components for orchestration at the system level and, most recently in Llama 3.1, we’ve introduced the Llama Guard 3 safety model that is multilingual, a prompt injection filter, Prompt Guard and refreshed v3 of our CyberSec Evals. We are also releasing a reference implementation of an agentic system to demonstrate how all the pieces fit together. While building the reference implementation, we realized that having a clean and consistent way to interface between components could be valuable not only for us but for anyone leveraging Llama models and other components as part of their system. 
We’ve also heard from the community as they face a similar challenge as components exist with overlapping functionality and there are incompatible interfaces and yet don't cover the end-to-end model life cycle. @@ -16,22 +19,21 @@ With these motivations, we engaged folks in industry, startups, and the broader We welcome feedback and ways to improve the proposal. We’re excited to grow the ecosystem around Llama and lower barriers for both developers and platform providers. ## Design decisions -Meta releases weights of both the pretrained and instruction fine-tuned Llama models to support several use cases. These weights can be improved - fine tuned and aligned - with curated datasets to then be deployed for inference to support specific applications. The curated datasets can be produced manually by humans or synthetically by other models or by leveraging human feedback by collecting usage data of the application itself. This results in a continuous improvement cycle where the model gets better over time. This is the model life cycle. +Meta releases weights of both the pretrained and instruction fine-tuned Llama models to support several use cases. These weights can be improved - fine tuned and aligned - with curated datasets to then be deployed for inference to support specific applications. The curated datasets can be produced manually by humans or synthetically by other models or by leveraging human feedback by collecting usage data of the application itself. This results in a continuous improvement cycle where the model gets better over time. This is the model life cycle. ### Model Lifecycle ![Figure 1: Model Life Cycle](../docs/resources/model-lifecycle.png) - For each of the operations that need to be performed (e.g. fine tuning, inference, evals etc) during the model life cycle, we identified the capabilities as toolchain APIs that are needed. Some of these capabilities are primitive operations like inference while other capabilities like synthetic data generation are composed of other capabilities. The list of APIs we have identified to support the lifecycle of Llama models is below: -- /datasets - to support creating training and evaluation data sets -- /post_training - to support creating and managing supervised finetuning (SFT) or preference optimization jobs -- /evaluations - to support creating and managing evaluations for capabilities like question answering, summarization, or text - generation -- /synthetic_data_generation - to support generating synthetic data using data generation model and a reward model -- /reward_scoring - to support synthetic data generation -- /inference - to support serving the models for applications +* /datasets - to support creating training and evaluation data sets +* /post_training - to support creating and managing supervised finetuning (SFT) or preference optimization jobs +* /evaluations - to support creating and managing evaluations for capabilities like question answering, summarization, or text - generation +* /synthetic_data_generation - to support generating synthetic data using data generation model and a reward model +* /reward_scoring - to support synthetic data generation +* /inference - to support serving the models for applications ### Agentic System @@ -41,6 +43,7 @@ In addition to the model lifecycle, we considered the different components invol Note that as of today, in the OSS world, such a “loop” is often coded explicitly via elaborate prompt engineering using a ReAct pattern (typically) or preconstructed execution graph. 
Llama 3.1 (and future Llamas) attempts to absorb this multi-step reasoning loop inside the main model itself. **Let's consider an example:** + 1. The user asks the system "Who played the NBA finals last year?" 1. The model "understands" that this question needs to be answered using web search. It answers this abstractly with a message of the form "Please call the search tool for me with the query: 'List finalist teams for NBA in the last year' ". Note that the model by itself does not call the tool (of course!) 1. The executor consults the set of tool implementations which have been configured by the developer to find an implementation for the "search tool". If it does not find it, it returns an error to the model. Otherwise, it executes this tool and returns the result of this tool back to the model. @@ -62,14 +65,7 @@ We define the Llama Stack as a layer cake shown below. ![Figure 3: Llama Stack](../docs/resources/llama-stack.png) - - - -The API is defined in the [YAML](../docs/resources/llama-stack-spec.yaml) and [HTML](../docs/resources/llama-stack-spec.html) files. These files were generated using the Pydantic definitions in (api/datatypes.py and api/endpoints.py) files that are in the llama-models, llama-stack, and llama-agentic-system repositories. - - - - +The API is defined in the [YAML](../docs/_static/llama-stack-spec.yaml) and [HTML](../docs/_static/llama-stack-spec.html) files. These files were generated using the Pydantic definitions in (api/datatypes.py and api/endpoints.py) files that are in the llama-models, llama-stack, and llama-agentic-system repositories. ## Sample implementations @@ -77,8 +73,8 @@ To prove out the API, we implemented a handful of use cases to make things more There is also a sample inference endpoint implementation in the [llama-stack](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/distribution/server/server.py) repository. - ## Limitations + The reference implementation for Llama Stack APIs to date only includes sample implementations using the inference API. We are planning to flesh out the design of Llama Stack Distributions (distros) by combining capabilities from different providers into a single vertically integrated stack. We plan to implement other APIs and, of course, we’d love contributions!! Thank you in advance for your feedback, support and contributions to make this a better API. 
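
The executor loop the RFC describes above — the model emits an abstract "please call this tool" message, the executor looks up a configured implementation, runs it, feeds the result back, and the model then produces a final answer — can be sketched in a few lines. This is a minimal, self-contained illustration only: the message shapes, the tool registry, and `fake_model` below are hypothetical stand-ins, not part of the Llama Stack API.

```python
# Sketch of the multi-step tool-calling loop described in the RFC.
# All names here (Message, ToolCall, TOOL_REGISTRY, fake_model) are
# illustrative assumptions, not Llama Stack types or endpoints.
from dataclasses import dataclass


@dataclass
class ToolCall:
    tool_name: str
    arguments: dict


@dataclass
class Message:
    role: str                      # "user", "assistant", or "tool"
    content: str = ""
    tool_call: ToolCall | None = None


def search_tool(query: str) -> str:
    # Stand-in for a developer-configured web-search tool.
    return f"(search results for: {query})"


TOOL_REGISTRY = {"search": search_tool}


def fake_model(history: list[Message]) -> Message:
    # Stand-in for inference: first ask for the search tool, then,
    # once a tool result is in the history, produce a final answer.
    if any(m.role == "tool" for m in history):
        return Message(role="assistant", content="(final answer synthesized from the tool result)")
    return Message(
        role="assistant",
        tool_call=ToolCall(tool_name="search", arguments={"query": "NBA finals teams last year"}),
    )


def run_turn(user_input: str) -> str:
    history = [Message(role="user", content=user_input)]
    while True:
        reply = fake_model(history)
        history.append(reply)
        if reply.tool_call is None:
            return reply.content                      # model answered directly
        impl = TOOL_REGISTRY.get(reply.tool_call.tool_name)
        if impl is None:
            # Tool not configured: surface an error back to the model.
            history.append(Message(role="tool", content="error: tool not found"))
            continue
        result = impl(**reply.tool_call.arguments)    # executor runs the tool
        history.append(Message(role="tool", content=result))


print(run_turn("Who played the NBA finals last year?"))
```

In the real stack this loop is what the Agents API encapsulates; the agent test changes later in this patch (the `tool_config` / `tool_choice` additions) exercise exactly this turn structure end to end.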
diff --git a/tests/client-sdk/README.md b/tests/client-sdk/README.md index d4d439d96..703d06a39 100644 --- a/tests/client-sdk/README.md +++ b/tests/client-sdk/README.md @@ -3,19 +3,16 @@ You can run llama stack integration tests on either a Llama Stack Library or a L To test on a Llama Stack library with certain configuration, run ```bash -LLAMA_STACK_CONFIG=./llama_stack/templates/cerebras/run.yaml -pytest -s -v tests/client-sdk/inference/ +LLAMA_STACK_CONFIG=./llama_stack/templates/cerebras/run.yaml pytest -s -v tests/client-sdk/inference/ ``` or just the template name ```bash -LLAMA_STACK_CONFIG=together -pytest -s -v tests/client-sdk/inference/ +LLAMA_STACK_CONFIG=together pytest -s -v tests/client-sdk/inference/ ``` To test on a Llama Stack endpoint, run ```bash -LLAMA_STACK_BASE_URL=http//localhost:8089 -pytest -s -v tests/client-sdk/inference +LLAMA_STACK_BASE_URL=http://localhost:8089 pytest -s -v tests/client-sdk/inference ``` ## Report Generation diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index f42341f72..e5380d357 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -98,7 +98,6 @@ def agent_config(llama_stack_client, text_model_id): }, }, toolgroups=[], - tool_choice="auto", input_shields=available_shields, output_shields=available_shields, enable_session_persistence=False, @@ -319,7 +318,39 @@ def test_custom_tool(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "-100" in logs_str - assert "CustomTool" in logs_str + assert "get_boiling_point" in logs_str + + +def test_tool_choice(llama_stack_client, agent_config): + data = [ + ("required", '{"type": "function"'), + ("none", None), + ("get_boiling_point", '{"type": "function", "name": "get_boiling_point"'), + ] + client_tool = TestClientTool() + for tool_choice, expected_tool in data: + agent_config["tool_config"] = {"tool_choice": tool_choice} + agent_config["client_tools"] = [client_tool.get_tool_definition()] + + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "What is the boiling point of polyjuice?", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + if expected_tool: + assert expected_tool in logs_str + else: + assert '{"type": "function"' not in logs_str # TODO: fix this flaky test @@ -403,7 +434,7 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config): logs_str = "".join(logs) print(logs_str) assert "-100" in logs_str - assert "CustomTool" in logs_str + assert "get_boiling_point" in logs_str def test_rag_agent(llama_stack_client, agent_config): @@ -527,3 +558,42 @@ def test_rag_and_code_agent(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) assert f"Tool:{tool_name}" in logs_str + + +def test_create_turn_response(llama_stack_client, agent_config): + client_tool = TestClientTool() + agent_config = { + **agent_config, + "input_shields": [], + "output_shields": [], + "client_tools": [client_tool.get_tool_definition()], + } + + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + session_id = 
agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "Call get_boiling_point and answer What is the boiling point of polyjuice?", + }, + ], + session_id=session_id, + stream=False, + ) + steps = response.steps + assert len(steps) == 3 + assert steps[0].step_type == "inference" + assert steps[1].step_type == "tool_execution" + assert steps[1].tool_calls[0].tool_name == "get_boiling_point" + assert steps[2].step_type == "inference" + + last_step_completed_at = None + for step in steps: + if last_step_completed_at is None: + last_step_completed_at = step.completed_at + else: + assert last_step_completed_at < step.started_at + assert step.started_at < step.completed_at + last_step_completed_at = step.completed_at diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py index c931ca255..6a113c463 100644 --- a/tests/client-sdk/inference/test_text_inference.py +++ b/tests/client-sdk/inference/test_text_inference.py @@ -247,6 +247,40 @@ def test_text_chat_completion_with_tool_calling_and_streaming( assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" +def test_text_chat_completion_with_tool_choice_required( + llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format, inference_provider_type +): + response = llama_stack_client.inference.chat_completion( + model_id=text_model_id, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather like in San Francisco?"}, + ], + tools=[get_weather_tool_definition], + tool_config={"tool_choice": "required", "tool_prompt_format": provider_tool_format}, + stream=True, + ) + tool_invocation_content = extract_tool_invocation_content(response) + assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" + + +def test_text_chat_completion_with_tool_choice_none( + llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format +): + response = llama_stack_client.inference.chat_completion( + model_id=text_model_id, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather like in San Francisco?"}, + ], + tools=[get_weather_tool_definition], + tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format}, + stream=True, + ) + tool_invocation_content = extract_tool_invocation_content(response) + assert tool_invocation_content == "" + + def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type): class AnswerFormat(BaseModel): first_name: str diff --git a/tests/client-sdk/report.py b/tests/client-sdk/report.py index 543562541..d36fa827f 100644 --- a/tests/client-sdk/report.py +++ b/tests/client-sdk/report.py @@ -13,8 +13,12 @@ from typing import Optional from urllib.parse import urlparse import pytest -from llama_models.datatypes import CoreModelId -from llama_models.sku_list import ( +from metadata import API_MAPS +from pytest import CollectReport +from termcolor import cprint + +from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_list import ( all_registered_models, llama3_1_instruct_models, llama3_2_instruct_models, @@ -22,10 +26,6 @@ from llama_models.sku_list import ( llama3_instruct_models, safety_models, ) -from metadata import API_MAPS -from pytest import CollectReport -from 
termcolor import cprint - from llama_stack.providers.datatypes import Api from llama_stack.providers.tests.env import get_env_or_fail diff --git a/tests/client-sdk/vector_io/test_vector_io.py b/tests/client-sdk/vector_io/test_vector_io.py index c5be4ab3f..c7e4040b6 100644 --- a/tests/client-sdk/vector_io/test_vector_io.py +++ b/tests/client-sdk/vector_io/test_vector_io.py @@ -8,7 +8,11 @@ import random import pytest -INLINE_VECTOR_DB_PROVIDERS = ["faiss", "sqlite_vec"] +INLINE_VECTOR_DB_PROVIDERS = [ + "faiss", + # TODO: add sqlite_vec to templates + # "sqlite_vec", +] @pytest.fixture(scope="function") diff --git a/uv.lock b/uv.lock index 97ae52124..336d67c0b 100644 --- a/uv.lock +++ b/uv.lock @@ -701,7 +701,7 @@ wheels = [ [[package]] name = "llama-models" -version = "0.1.2" +version = "0.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jinja2" }, @@ -710,20 +710,21 @@ dependencies = [ { name = "pyyaml" }, { name = "tiktoken" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b5/f2/ed8310d4677cd38ab45ffba45aea2a4e9882b640045ad9c3198ac69e5a85/llama_models-0.1.2.tar.gz", hash = "sha256:1266eaec7a8db336e4ed034d2b494189ccb7fd6d6b7aefe874eee749a4340b9b", size = 1608069 } +sdist = { url = "https://files.pythonhosted.org/packages/0b/39/b8e2c02bc5ce1c0ba4e249532e0eb384ad7dae54a8f53198c8ff9aded41e/llama_models-0.1.3.tar.gz", hash = "sha256:2f339e67b8bbd98729bd2052c2cb8a916ef8f7d8a05337febad8879c6718c24a", size = 1568353 } wheels = [ - { url = "https://files.pythonhosted.org/packages/55/a7/34b9e88ef4109759c8881f43b8006139e3d13d54c440b8c571b253655f54/llama_models-0.1.2-py3-none-any.whl", hash = "sha256:8aa5287d1c6325698991ff677e71148cac347e07493bb5b3ab891e614b89e1f8", size = 1651273 }, + { url = "https://files.pythonhosted.org/packages/8c/df/a39f85cce6fcab962f7a7113063a6b2b08d0f66ac8ba4b9b12f21f398885/llama_models-0.1.3-py3-none-any.whl", hash = "sha256:87d92027e27c6b3e905158751758bcb7dabbdca1d995592e8e46fd2160daa844", size = 1587292 }, ] [[package]] name = "llama-stack" -version = "0.1.2" +version = "0.1.3" source = { editable = "." 
} dependencies = [ { name = "blobfile" }, { name = "fire" }, { name = "httpx" }, { name = "huggingface-hub" }, + { name = "jsonschema" }, { name = "llama-models" }, { name = "llama-stack-client" }, { name = "prompt-toolkit" }, @@ -768,8 +769,9 @@ requires-dist = [ { name = "fire" }, { name = "httpx" }, { name = "huggingface-hub" }, - { name = "llama-models", specifier = ">=0.1.2" }, - { name = "llama-stack-client", specifier = ">=0.1.2" }, + { name = "jsonschema" }, + { name = "llama-models", specifier = ">=0.1.3" }, + { name = "llama-stack-client", specifier = ">=0.1.3" }, { name = "myst-parser", marker = "extra == 'docs'" }, { name = "nbval", marker = "extra == 'dev'" }, { name = "pre-commit", marker = "extra == 'dev'" }, @@ -798,7 +800,7 @@ requires-dist = [ [[package]] name = "llama-stack-client" -version = "0.1.2" +version = "0.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -815,9 +817,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9e/75/8b41a3026c871a8650cd8d2cfda9f891a9163458813574f36518bb40afe4/llama_stack_client-0.1.2.tar.gz", hash = "sha256:94277ddae52be557d771dcdc15d85af9012b5aa87439dd69ec1dc0ff486b0c8e", size = 188023 } +sdist = { url = "https://files.pythonhosted.org/packages/23/bb/f8b21745fcae811d75685202fe127c269f8387ff6374cf8f9b0be9b7eaa7/llama_stack_client-0.1.3.tar.gz", hash = "sha256:8ba46e199ac1a0e0bdcbe55fc776dd0b8f55771418c5f8bf7b419b7a0077fe7a", size = 191842 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/32/3a3a97eecff1f1e3a1dc90e9b00681abea11ec4f43a7ca549981261e18b6/llama_stack_client-0.1.2-py3-none-any.whl", hash = "sha256:85ff0fb57a62d7d0470cfaa2b07a595c9fb3483297944d5e5a066db850d38ccd", size = 359415 }, + { url = "https://files.pythonhosted.org/packages/88/52/3ef8405daad5649f11b5708f1df9eca4fa229e499ac198a99c42f1075a08/llama_stack_client-0.1.3-py3-none-any.whl", hash = "sha256:e7b66051918bc0685dfee6103d3efbcec3ae193b3e67edf025cd088539463245", size = 366471 }, ] [[package]]