diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
index 39505ba11..59d18b3be 100644
--- a/.github/workflows/unit-tests.yml
+++ b/.github/workflows/unit-tests.yml
@@ -1,6 +1,8 @@
 name: Unit Tests

 on:
+  push:
+    branches: [ main ]
   pull_request:
     branches: [ main ]
   workflow_dispatch:
diff --git a/README.md b/README.md
index b24e69514..6e1fd088e 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@
 [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/)
 [![License](https://img.shields.io/pypi/l/llama_stack.svg)](https://github.com/meta-llama/llama-stack/blob/main/LICENSE)
 [![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/llama-stack)
+![Unit](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)

 [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 740bff6e4..821e5ed53 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4556,7 +4556,7 @@
           "metrics": {
             "type": "array",
             "items": {
-              "$ref": "#/components/schemas/MetricEvent"
+              "$ref": "#/components/schemas/MetricInResponse"
             }
           },
           "completion_message": {
@@ -4578,46 +4578,9 @@
         "title": "ChatCompletionResponse",
         "description": "Response from a chat completion request."
       },
-      "MetricEvent": {
+      "MetricInResponse": {
         "type": "object",
         "properties": {
-          "trace_id": {
-            "type": "string"
-          },
-          "span_id": {
-            "type": "string"
-          },
-          "timestamp": {
-            "type": "string",
-            "format": "date-time"
-          },
-          "attributes": {
-            "type": "object",
-            "additionalProperties": {
-              "oneOf": [
-                {
-                  "type": "string"
-                },
-                {
-                  "type": "integer"
-                },
-                {
-                  "type": "number"
-                },
-                {
-                  "type": "boolean"
-                },
-                {
-                  "type": "null"
-                }
-              ]
-            }
-          },
-          "type": {
-            "type": "string",
-            "const": "metric",
-            "default": "metric"
-          },
           "metric": {
             "type": "string"
           },
           "value": {
@@ -4637,15 +4600,10 @@
           },
         "additionalProperties": false,
         "required": [
-          "trace_id",
-          "span_id",
-          "timestamp",
-          "type",
           "metric",
-          "value",
-          "unit"
+          "value"
         ],
-        "title": "MetricEvent"
+        "title": "MetricInResponse"
       },
       "TokenLogProbs": {
        "type": "object",
@@ -4722,6 +4680,12 @@
       "CompletionResponse": {
         "type": "object",
         "properties": {
+          "metrics": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/MetricInResponse"
+            }
+          },
           "content": {
             "type": "string",
             "description": "The generated completion text"
@@ -4931,7 +4895,7 @@
           "metrics": {
             "type": "array",
             "items": {
-              "$ref": "#/components/schemas/MetricEvent"
+              "$ref": "#/components/schemas/MetricInResponse"
             }
           },
           "event": {
@@ -5089,6 +5053,12 @@
       "CompletionResponseStreamChunk": {
         "type": "object",
         "properties": {
+          "metrics": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/MetricInResponse"
+            }
+          },
           "delta": {
             "type": "string",
             "description": "New content generated since last chunk. This can be one or more tokens."
@@ -8501,6 +8471,75 @@
         ],
         "title": "LogSeverity"
       },
+      "MetricEvent": {
+        "type": "object",
+        "properties": {
+          "trace_id": {
+            "type": "string"
+          },
+          "span_id": {
+            "type": "string"
+          },
+          "timestamp": {
+            "type": "string",
+            "format": "date-time"
+          },
+          "attributes": {
+            "type": "object",
+            "additionalProperties": {
+              "oneOf": [
+                {
+                  "type": "string"
+                },
+                {
+                  "type": "integer"
+                },
+                {
+                  "type": "number"
+                },
+                {
+                  "type": "boolean"
+                },
+                {
+                  "type": "null"
+                }
+              ]
+            }
+          },
+          "type": {
+            "type": "string",
+            "const": "metric",
+            "default": "metric"
+          },
+          "metric": {
+            "type": "string"
+          },
+          "value": {
+            "oneOf": [
+              {
+                "type": "integer"
+              },
+              {
+                "type": "number"
+              }
+            ]
+          },
+          "unit": {
+            "type": "string"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "trace_id",
+          "span_id",
+          "timestamp",
+          "type",
+          "metric",
+          "value",
+          "unit"
+        ],
+        "title": "MetricEvent"
+      },
       "SpanEndPayload": {
         "type": "object",
         "properties": {
@@ -9625,21 +9664,11 @@
         "type": "object",
         "properties": {
           "tool_responses": {
-            "oneOf": [
-              {
-                "type": "array",
-                "items": {
-                  "$ref": "#/components/schemas/ToolResponse"
-                }
-              },
-              {
-                "type": "array",
-                "items": {
-                  "$ref": "#/components/schemas/ToolResponseMessage"
-                }
-              }
-            ],
-            "description": "The tool call responses to resume the turn with. NOTE: ToolResponseMessage will be deprecated. Use ToolResponse."
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/ToolResponse"
+            },
+            "description": "The tool call responses to resume the turn with."
           },
           "stream": {
             "type": "boolean",
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index c6f3c2327..21625827a 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -3105,7 +3105,7 @@ components:
         metrics:
           type: array
           items:
-            $ref: '#/components/schemas/MetricEvent'
+            $ref: '#/components/schemas/MetricInResponse'
         completion_message:
           $ref: '#/components/schemas/CompletionMessage'
           description: The complete response message
@@ -3120,29 +3120,9 @@
         - completion_message
       title: ChatCompletionResponse
       description: Response from a chat completion request.
-    MetricEvent:
+    MetricInResponse:
       type: object
       properties:
-        trace_id:
-          type: string
-        span_id:
-          type: string
-        timestamp:
-          type: string
-          format: date-time
-        attributes:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: string
-              - type: integer
-              - type: number
-              - type: boolean
-              - type: 'null'
-        type:
-          type: string
-          const: metric
-          default: metric
         metric:
           type: string
         value:
@@ -3153,14 +3133,9 @@
           type: string
       additionalProperties: false
       required:
-        - trace_id
-        - span_id
-        - timestamp
-        - type
         - metric
         - value
-        - unit
-      title: MetricEvent
+      title: MetricInResponse
     TokenLogProbs:
       type: object
       properties:
@@ -3217,6 +3192,10 @@
     CompletionResponse:
      type: object
       properties:
+        metrics:
+          type: array
+          items:
+            $ref: '#/components/schemas/MetricInResponse'
         content:
           type: string
           description: The generated completion text
@@ -3416,7 +3395,7 @@
         metrics:
           type: array
           items:
-            $ref: '#/components/schemas/MetricEvent'
+            $ref: '#/components/schemas/MetricInResponse'
         event:
           $ref: '#/components/schemas/ChatCompletionResponseEvent'
           description: The event containing the new content
@@ -3535,6 +3514,10 @@
     CompletionResponseStreamChunk:
       type: object
       properties:
+        metrics:
+          type: array
+          items:
+            $ref: '#/components/schemas/MetricInResponse'
         delta:
           type: string
           description: >-
@@ -5784,6 +5767,47 @@
         - error
         - critical
       title: LogSeverity
+    MetricEvent:
+      type: object
+      properties:
+        trace_id:
+          type: string
+        span_id:
+          type: string
+        timestamp:
+          type: string
+          format: date-time
+        attributes:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: string
+              - type: integer
+              - type: number
+              - type: boolean
+              - type: 'null'
+        type:
+          type: string
+          const: metric
+          default: metric
+        metric:
+          type: string
+        value:
+          oneOf:
+            - type: integer
+            - type: number
+        unit:
+          type: string
+      additionalProperties: false
+      required:
+        - trace_id
+        - span_id
+        - timestamp
+        - type
+        - metric
+        - value
+        - unit
+      title: MetricEvent
     SpanEndPayload:
       type: object
       properties:
@@ -6495,16 +6519,11 @@
       type: object
       properties:
         tool_responses:
-          oneOf:
-            - type: array
-              items:
-                $ref: '#/components/schemas/ToolResponse'
-            - type: array
-              items:
-                $ref: '#/components/schemas/ToolResponseMessage'
+          type: array
+          items:
+            $ref: '#/components/schemas/ToolResponse'
           description: >-
-            The tool call responses to resume the turn with. NOTE: ToolResponseMessage
-            will be deprecated. Use ToolResponse.
+            The tool call responses to resume the turn with.
         stream:
           type: boolean
           description: Whether to stream the response.
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 1170a56d5..5cc910a55 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -370,7 +370,7 @@ class AgentTurnResumeRequest(BaseModel):
     agent_id: str
     session_id: str
     turn_id: str
-    tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]]
+    tool_responses: List[ToolResponse]
     stream: Optional[bool] = False
@@ -449,7 +449,7 @@
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
+        tool_responses: List[ToolResponse],
         stream: Optional[bool] = False,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
         """Resume an agent turn with executed tool call responses.
@@ -460,7 +460,6 @@
         :param session_id: The ID of the session to resume.
         :param turn_id: The ID of the turn to resume.
         :param tool_responses: The tool call responses to resume the turn with.
-            NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
         :param stream: Whether to stream the response.
         :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
         """
diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py
index fe75677e7..cbea57e79 100644
--- a/llama_stack/apis/telemetry/telemetry.py
+++ b/llama_stack/apis/telemetry/telemetry.py
@@ -96,6 +96,13 @@ class MetricEvent(EventCommon):
     unit: str


+@json_schema_type
+class MetricInResponse(BaseModel):
+    metric: str
+    value: Union[int, float]
+    unit: Optional[str] = None
+
+
 # This is a short term solution to allow inference API to return metrics
 # The ideal way to do this is to have a way for all response types to include metrics
 # and all metric events logged to the telemetry API to be inlcuded with the response
@@ -117,7 +124,7 @@
 class MetricResponseMixin(BaseModel):
-    metrics: Optional[List[MetricEvent]] = None
+    metrics: Optional[List[MetricInResponse]] = None


 @json_schema_type
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 68b8e55cb..22a1e46f9 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -48,7 +48,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.telemetry import MetricEvent, Telemetry
+from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -206,12 +206,12 @@
         completion_tokens: int,
         total_tokens: int,
         model: Model,
-    ) -> List[MetricEvent]:
+    ) -> List[MetricInResponse]:
         metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return metrics
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def _count_tokens(
         self,
@@ -238,7 +238,6 @@
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
-            "core",
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
         if sampling_params is None:
diff --git a/llama_stack/distribution/utils/context.py b/llama_stack/distribution/utils/context.py
index 107ce7127..2f32afba2 100644
--- a/llama_stack/distribution/utils/context.py
+++ b/llama_stack/distribution/utils/context.py
@@ -19,7 +19,7 @@ def preserve_contexts_async_generator(
     and we need to preserve the context across the event loop boundary.
""" - async def wrapper(): + async def wrapper() -> AsyncGenerator[T, None]: while True: try: item = await gen.__anext__() diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index fedd695c1..1d9f54e96 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -218,18 +218,10 @@ class ChatAgent(ShieldRunnerMixin): steps = [] messages = await self.get_messages_from_turns(turns) if is_resume: - if isinstance(request.tool_responses[0], ToolResponseMessage): - tool_response_messages = request.tool_responses - tool_responses = [ - ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content) - for x in request.tool_responses - ] - else: - tool_response_messages = [ - ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content) - for x in request.tool_responses - ] - tool_responses = request.tool_responses + tool_response_messages = [ + ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content) + for x in request.tool_responses + ] messages.extend(tool_response_messages) last_turn = turns[-1] last_turn_messages = self.turn_to_messages(last_turn) @@ -252,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin): step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), turn_id=request.turn_id, tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), - tool_responses=tool_responses, + tool_responses=request.tool_responses, completed_at=now, started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index c24b14e35..5ca123595 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -172,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents): agent_id: str, session_id: str, turn_id: str, - tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]], + tool_responses: List[ToolResponse], stream: Optional[bool] = False, ) -> AsyncGenerator: request = AgentTurnResumeRequest(