feat: add MetricResponseMixin to chat completion response types (#1050)

# What does this PR do?
Defines a `MetricResponseMixin` that can be inherited by any response
class, and adds it to the chat completion response types.


This is a short-term solution to allow the inference API to return metrics.
The ideal approach is for every response type to include a metrics field, and
for all metric events logged to the telemetry API to be included with the
response. To do that, we would need to augment all response types with a
metrics field, but we have hit a blocker in the Stainless SDK that prevents
this. The blocker: if we augment a response type that carries a data field,
like so

```python
class ListModelsResponse(BaseModel):
    metrics: Optional[List[MetricEvent]] = None
    data: List[Models]
    ...
```

the client SDK has to access the data through a `.data` field, which is not
ergonomic. Stainless does support unwrapping the response type, but it
requires that the response type have only a single field.

We will need a way for the client SDK to signal that metrics are wanted and,
when they are, to return the full response type without unwrapping it.
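To make the pattern concrete, here is a minimal sketch of how the mixin composes with a response type. Only `MetricResponseMixin` and the inheritance order come from this PR; the simplified `MetricEvent` and the string-typed `completion_message` are stand-ins for the real types:

```python
from typing import List, Optional

from pydantic import BaseModel


class MetricEvent(BaseModel):
    # Simplified stand-in for the telemetry MetricEvent (the real one also
    # carries trace/span ids, a timestamp, and free-form attributes).
    metric: str
    value: float
    unit: str


class MetricResponseMixin(BaseModel):
    # Optional with a None default, so responses without metrics are unaffected.
    metrics: Optional[List[MetricEvent]] = None


class ChatCompletionResponse(MetricResponseMixin, BaseModel):
    completion_message: str  # stand-in for the real CompletionMessage type


resp = ChatCompletionResponse(
    completion_message="hello",
    metrics=[MetricEvent(metric="completion_tokens", value=42, unit="tokens")],
)
print(resp.model_dump())
# {'metrics': [{'metric': 'completion_tokens', 'value': 42.0, 'unit': 'tokens'}],
#  'completion_message': 'hello'}
```

Listing the mixin first means pydantic collects its fields first, which is why `metrics` appears ahead of the other properties in the regenerated schemas below.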

## Test Plan
```sh
sh run_openapi_generator.sh ./
sh stainless_sync.sh dineshyv/dev add-metrics-to-resp-v4

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/fireworks/fireworks-run.yaml" \
  pytest -v tests/client-sdk/agents/test_agents.py
```
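As a quick end-to-end check of the new field, something along these lines should work once a stack is running (a hedged sketch — the base URL and model id are illustrative, not part of this PR):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # illustrative endpoint
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # any model the stack serves
    messages=[{"role": "user", "content": "Say hi"}],
)
# With the mixin in place, the response exposes a metrics field
# (None unless the server actually populated it).
print(response.metrics)
```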
Dinesh Yeduguru, 2025-02-11 14:58:12 -08:00, committed by GitHub
commit ab7f802698 (parent 96c88397da)
4 changed files with 161 additions and 116 deletions

File 1 of 4: generated OpenAPI spec (JSON)

@@ -3106,6 +3106,12 @@
    "ChatCompletionResponse": {
      "type": "object",
      "properties": {
        "metrics": {
          "type": "array",
          "items": {
            "$ref": "#/components/schemas/MetricEvent"
          }
        },
        "completion_message": {
          "$ref": "#/components/schemas/CompletionMessage",
          "description": "The complete response message"
@@ -3124,6 +3130,77 @@
      ],
      "description": "Response from a chat completion request."
    },
    "MetricEvent": {
      "type": "object",
      "properties": {
        "trace_id": {
          "type": "string"
        },
        "span_id": {
          "type": "string"
        },
        "timestamp": {
          "type": "string",
          "format": "date-time"
        },
        "attributes": {
          "type": "object",
          "additionalProperties": {
            "oneOf": [
              {
                "type": "null"
              },
              {
                "type": "boolean"
              },
              {
                "type": "number"
              },
              {
                "type": "string"
              },
              {
                "type": "array"
              },
              {
                "type": "object"
              }
            ]
          }
        },
        "type": {
          "type": "string",
          "const": "metric",
          "default": "metric"
        },
        "metric": {
          "type": "string"
        },
        "value": {
          "oneOf": [
            {
              "type": "integer"
            },
            {
              "type": "number"
            }
          ]
        },
        "unit": {
          "type": "string"
        }
      },
      "additionalProperties": false,
      "required": [
        "trace_id",
        "span_id",
        "timestamp",
        "type",
        "metric",
        "value",
        "unit"
      ]
    },
    "TokenLogProbs": {
      "type": "object",
      "properties": {
@@ -3388,6 +3465,12 @@
    "ChatCompletionResponseStreamChunk": {
      "type": "object",
      "properties": {
        "metrics": {
          "type": "array",
          "items": {
            "$ref": "#/components/schemas/MetricEvent"
          }
        },
        "event": {
          "$ref": "#/components/schemas/ChatCompletionResponseEvent",
          "description": "The event containing the new content"
@@ -6374,77 +6457,6 @@
        "critical"
      ]
    },
    "MetricEvent": {
      "type": "object",
      "properties": {
        "trace_id": {
          "type": "string"
        },
        "span_id": {
          "type": "string"
        },
        "timestamp": {
          "type": "string",
          "format": "date-time"
        },
        "attributes": {
          "type": "object",
          "additionalProperties": {
            "oneOf": [
              {
                "type": "null"
              },
              {
                "type": "boolean"
              },
              {
                "type": "number"
              },
              {
                "type": "string"
              },
              {
                "type": "array"
              },
              {
                "type": "object"
              }
            ]
          }
        },
        "type": {
          "type": "string",
          "const": "metric",
          "default": "metric"
        },
        "metric": {
          "type": "string"
        },
        "value": {
          "oneOf": [
            {
              "type": "integer"
            },
            {
              "type": "number"
            }
          ]
        },
        "unit": {
          "type": "string"
        }
      },
      "additionalProperties": false,
      "required": [
        "trace_id",
        "span_id",
        "timestamp",
        "type",
        "metric",
        "value",
        "unit"
      ]
    },
    "SpanEndPayload": {
      "type": "object",
      "properties": {

File 2 of 4: generated OpenAPI spec (YAML)

@@ -1925,6 +1925,10 @@ components:
    ChatCompletionResponse:
      type: object
      properties:
        metrics:
          type: array
          items:
            $ref: '#/components/schemas/MetricEvent'
        completion_message:
          $ref: '#/components/schemas/CompletionMessage'
          description: The complete response message
@@ -1938,6 +1942,47 @@ components:
      required:
        - completion_message
      description: Response from a chat completion request.
    MetricEvent:
      type: object
      properties:
        trace_id:
          type: string
        span_id:
          type: string
        timestamp:
          type: string
          format: date-time
        attributes:
          type: object
          additionalProperties:
            oneOf:
              - type: 'null'
              - type: boolean
              - type: number
              - type: string
              - type: array
              - type: object
        type:
          type: string
          const: metric
          default: metric
        metric:
          type: string
        value:
          oneOf:
            - type: integer
            - type: number
        unit:
          type: string
      additionalProperties: false
      required:
        - trace_id
        - span_id
        - timestamp
        - type
        - metric
        - value
        - unit
    TokenLogProbs:
      type: object
      properties:
@@ -2173,6 +2218,10 @@ components:
    ChatCompletionResponseStreamChunk:
      type: object
      properties:
        metrics:
          type: array
          items:
            $ref: '#/components/schemas/MetricEvent'
        event:
          $ref: '#/components/schemas/ChatCompletionResponseEvent'
          description: The event containing the new content
@@ -4070,47 +4119,6 @@ components:
        - warn
        - error
        - critical
    MetricEvent:
      type: object
      properties:
        trace_id:
          type: string
        span_id:
          type: string
        timestamp:
          type: string
          format: date-time
        attributes:
          type: object
          additionalProperties:
            oneOf:
              - type: 'null'
              - type: boolean
              - type: number
              - type: string
              - type: array
              - type: object
        type:
          type: string
          const: metric
          default: metric
        metric:
          type: string
        value:
          oneOf:
            - type: integer
            - type: number
        unit:
          type: string
      additionalProperties: false
      required:
        - trace_id
        - span_id
        - timestamp
        - type
        - metric
        - value
        - unit
    SpanEndPayload:
      type: object
      properties:

File 3 of 4: llama_stack/apis/inference/inference.py

@@ -13,8 +13,8 @@ from typing import (
    Literal,
    Optional,
    Protocol,
    runtime_checkable,
    Union,
    runtime_checkable,
)
from llama_models.llama3.api.datatypes import (
@@ -31,6 +31,7 @@ from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -357,7 +358,7 @@ class ChatCompletionRequest(BaseModel):
@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
    """A chunk of a streamed chat completion response.
    :param event: The event containing the new content
@@ -367,7 +368,7 @@ class ChatCompletionResponseStreamChunk(BaseModel):
@json_schema_type
class ChatCompletionResponse(BaseModel):
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
    """Response from a chat completion request.
    :param completion_message: The complete response message

File 4 of 4: llama_stack/apis/telemetry/telemetry.py

@@ -13,8 +13,8 @@ from typing import (
    Literal,
    Optional,
    Protocol,
    runtime_checkable,
    Union,
    runtime_checkable,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
@@ -94,6 +94,30 @@ class MetricEvent(EventCommon):
    unit: str


# This is a short-term solution to allow the inference API to return metrics.
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be included with the response.
# To do this, we will need to augment all response types with a metrics field.
# We have hit a blocker from the Stainless SDK that prevents us from doing this.
# The blocker is that if we were to augment the response types that have a data field
# in them like so
# class ListModelsResponse(BaseModel):
#     metrics: Optional[List[MetricEvent]] = None
#     data: List[Models]
#     ...
# the client SDK will need to access the data by using a .data field, which is not
# ergonomic. The Stainless SDK does support unwrapping the response type, but it
# requires that the response type have only a single field.
# We will need a way in the client SDK to signal that the metrics are needed
# and, if they are needed, the client SDK has to return the full response type
# without unwrapping it.
class MetricResponseMixin(BaseModel):
    metrics: Optional[List[MetricEvent]] = None
@json_schema_type
class StructuredLogType(Enum):
    SPAN_START = "span_start"
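One property of the mixin worth noting: because `metrics` defaults to `None`, payloads that predate this change still validate. A minimal sketch (the `ToyResponse` type and simplified `MetricEvent` are hypothetical stand-ins):

```python
from typing import List, Optional

from pydantic import BaseModel


class MetricEvent(BaseModel):  # simplified stand-in for the class above
    metric: str
    value: float
    unit: str


class MetricResponseMixin(BaseModel):
    metrics: Optional[List[MetricEvent]] = None


class ToyResponse(MetricResponseMixin, BaseModel):  # hypothetical response type
    answer: str


# A payload without a "metrics" key parses fine; the field is simply None.
resp = ToyResponse.model_validate({"answer": "ok"})
assert resp.metrics is None
```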