From 12cbed16178b157e45d30ffff20fc0038fe573ce Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 18 Dec 2024 10:32:25 -0800 Subject: [PATCH] Register Message and ResponseFormat --- docs/resources/llama-stack-spec.html | 336 ++++++++---------------- docs/resources/llama-stack-spec.yaml | 162 +++++------- llama_stack/apis/inference/inference.py | 32 ++- 3 files changed, 195 insertions(+), 335 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 050a16223..33112012b 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -2598,6 +2598,22 @@ } ] }, + "Message": { + "oneOf": [ + { + "$ref": "#/components/schemas/UserMessage" + }, + { + "$ref": "#/components/schemas/SystemMessage" + }, + { + "$ref": "#/components/schemas/ToolResponseMessage" + }, + { + "$ref": "#/components/schemas/CompletionMessage" + } + ] + }, "SamplingParams": { "type": "object", "properties": { @@ -2936,20 +2952,7 @@ "items": { "type": "array", "items": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserMessage" - }, - { - "$ref": "#/components/schemas/SystemMessage" - }, - { - "$ref": "#/components/schemas/ToolResponseMessage" - }, - { - "$ref": "#/components/schemas/CompletionMessage" - } - ] + "$ref": "#/components/schemas/Message" } } }, @@ -3059,6 +3062,90 @@ "job_uuid" ] }, + "ResponseFormat": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "json_schema", + "default": "json_schema" + }, + "json_schema": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "type", + "json_schema" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "grammar", + "default": "grammar" + }, + "bnf": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "type", + "bnf" + ] + } + ] + }, "ChatCompletionRequest": { "type": "object", "properties": { @@ -3068,20 +3155,7 @@ "messages": { "type": "array", "items": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserMessage" - }, - { - "$ref": "#/components/schemas/SystemMessage" - }, - { - "$ref": "#/components/schemas/ToolResponseMessage" - }, - { - "$ref": "#/components/schemas/CompletionMessage" - } - ] + "$ref": "#/components/schemas/Message" } }, "sampling_params": { @@ -3100,88 +3174,7 @@ "$ref": "#/components/schemas/ToolPromptFormat" }, "response_format": { - "oneOf": [ - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "json_schema", - "default": "json_schema" - }, - "json_schema": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "json_schema" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "grammar", - "default": "grammar" - }, - "bnf": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "bnf" - ] - } - ] + "$ref": "#/components/schemas/ResponseFormat" }, "stream": { "type": "boolean" @@ -3336,88 +3329,7 @@ "$ref": "#/components/schemas/SamplingParams" }, "response_format": { - "oneOf": [ - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "json_schema", - "default": "json_schema" - }, - "json_schema": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "json_schema" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "grammar", - "default": "grammar" - }, - "bnf": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "bnf" - ] - } - ] + "$ref": "#/components/schemas/ResponseFormat" }, "stream": { "type": "boolean" @@ -7285,20 +7197,7 @@ "messages": { "type": "array", "items": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserMessage" - }, - { - "$ref": "#/components/schemas/SystemMessage" - }, - { - "$ref": "#/components/schemas/ToolResponseMessage" - }, - { - "$ref": "#/components/schemas/CompletionMessage" - } - ] + "$ref": "#/components/schemas/Message" } }, "params": { @@ -7664,20 +7563,7 @@ "dialogs": { "type": "array", "items": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserMessage" - }, - { - "$ref": "#/components/schemas/SystemMessage" - }, - { - "$ref": "#/components/schemas/ToolResponseMessage" - }, - { - "$ref": "#/components/schemas/CompletionMessage" - } - ] + "$ref": "#/components/schemas/Message" } }, "filtering_function": { @@ -8136,6 +8022,10 @@ "name": "MemoryToolDefinition", "description": "" }, + { + "name": "Message", + "description": "" + }, { "name": "MetricEvent", "description": "" @@ -8254,6 +8144,10 @@ "name": "RegisterShieldRequest", "description": "" }, + { + "name": "ResponseFormat", + "description": "" + }, { "name": "RestAPIExecutionConfig", "description": "" @@ -8598,6 +8492,7 @@ "MemoryBankDocument", "MemoryRetrievalStep", "MemoryToolDefinition", + "Message", "MetricEvent", "Model", "ModelCandidate", @@ -8626,6 +8521,7 @@ "RegisterModelRequest", "RegisterScoringFunctionRequest", "RegisterShieldRequest", + "ResponseFormat", "RestAPIExecutionConfig", "RestAPIMethod", "RouteInfo", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index b5a209e89..abd57e17e 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -313,11 +313,7 @@ components: messages_batch: items: items: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: '#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' + $ref: '#/components/schemas/Message' type: array type: array model: @@ -422,56 +418,12 @@ components: type: object messages: items: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: '#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' + $ref: '#/components/schemas/Message' type: array model_id: type: string response_format: - oneOf: - - additionalProperties: false - properties: - json_schema: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: json_schema - default: json_schema - type: string - required: - - type - - json_schema - type: object - - additionalProperties: false - properties: - bnf: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: grammar - default: grammar - type: string - required: - - type - - bnf - type: object + $ref: '#/components/schemas/ResponseFormat' sampling_params: $ref: '#/components/schemas/SamplingParams' stream: @@ -598,47 +550,7 @@ components: model_id: type: string response_format: - oneOf: - - additionalProperties: false - properties: - json_schema: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: json_schema - default: json_schema - type: string - required: - - type - - json_schema - type: object - - additionalProperties: false - properties: - bnf: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: grammar - default: grammar - type: string - required: - - type - - bnf - type: object + $ref: '#/components/schemas/ResponseFormat' sampling_params: $ref: '#/components/schemas/SamplingParams' stream: @@ -1467,6 +1379,12 @@ components: - max_tokens_in_context - max_chunks type: object + Message: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' MetricEvent: additionalProperties: false properties: @@ -2121,6 +2039,48 @@ components: required: - shield_id type: object + ResponseFormat: + oneOf: + - additionalProperties: false + properties: + json_schema: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + type: + const: json_schema + default: json_schema + type: string + required: + - type + - json_schema + type: object + - additionalProperties: false + properties: + bnf: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + type: + const: grammar + default: grammar + type: string + required: + - type + - bnf + type: object RestAPIExecutionConfig: additionalProperties: false properties: @@ -2203,11 +2163,7 @@ components: properties: messages: items: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: '#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' + $ref: '#/components/schemas/Message' type: array params: additionalProperties: @@ -2744,11 +2700,7 @@ components: properties: dialogs: items: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: '#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' + $ref: '#/components/schemas/Message' type: array filtering_function: enum: @@ -5024,6 +4976,8 @@ tags: - description: name: MemoryToolDefinition +- description: + name: Message - description: name: MetricEvent - description: @@ -5108,6 +5062,8 @@ tags: - description: name: RegisterShieldRequest +- description: + name: ResponseFormat - description: name: RestAPIExecutionConfig @@ -5371,6 +5327,7 @@ x-tagGroups: - MemoryBankDocument - MemoryRetrievalStep - MemoryToolDefinition + - Message - MetricEvent - Model - ModelCandidate @@ -5399,6 +5356,7 @@ x-tagGroups: - RegisterModelRequest - RegisterScoringFunctionRequest - RegisterShieldRequest + - ResponseFormat - RestAPIExecutionConfig - RestAPIMethod - RouteInfo diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index c481d04d7..28b9d9106 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -25,7 +25,7 @@ from llama_models.llama3.api.datatypes import ( ToolPromptFormat, ) -from llama_models.schema_utils import json_schema_type, webmethod +from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field, field_validator from typing_extensions import Annotated @@ -100,15 +100,18 @@ class CompletionMessage(BaseModel): tool_calls: List[ToolCall] = Field(default_factory=list) -Message = Annotated[ - Union[ - UserMessage, - SystemMessage, - ToolResponseMessage, - CompletionMessage, +Message = register_schema( + Annotated[ + Union[ + UserMessage, + SystemMessage, + ToolResponseMessage, + CompletionMessage, + ], + Field(discriminator="role"), ], - Field(discriminator="role"), -] + name="Message", +) @json_schema_type @@ -187,10 +190,13 @@ class GrammarResponseFormat(BaseModel): bnf: Dict[str, Any] -ResponseFormat = Annotated[ - Union[JsonSchemaResponseFormat, GrammarResponseFormat], - Field(discriminator="type"), -] +ResponseFormat = register_schema( + Annotated[ + Union[JsonSchemaResponseFormat, GrammarResponseFormat], + Field(discriminator="type"), + ], + name="ResponseFormat", +) @json_schema_type