Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 02:32:40 +00:00)

Commit 8d7bb1140f: Merge remote-tracking branch 'origin/main' into support_more_data_format

20 changed files with 381 additions and 414 deletions
@@ -3843,8 +3843,8 @@
             "properties": {
                 "role": {
                     "type": "string",
-                    "const": "ipython",
-                    "default": "ipython"
+                    "const": "tool",
+                    "default": "tool"
                 },
                 "call_id": {
                     "type": "string"
@@ -4185,14 +4185,7 @@
                     "$ref": "#/components/schemas/ChatCompletionResponseEventType"
                 },
                 "delta": {
-                    "oneOf": [
-                        {
-                            "type": "string"
-                        },
-                        {
-                            "$ref": "#/components/schemas/ToolCallDelta"
-                        }
-                    ]
+                    "$ref": "#/components/schemas/ContentDelta"
                 },
                 "logprobs": {
                     "type": "array",
@@ -4232,6 +4225,50 @@
             ],
             "title": "SSE-stream of these events."
         },
+        "ContentDelta": {
+            "oneOf": [
+                {
+                    "type": "object",
+                    "properties": {
+                        "type": {
+                            "type": "string",
+                            "const": "text",
+                            "default": "text"
+                        },
+                        "text": {
+                            "type": "string"
+                        }
+                    },
+                    "additionalProperties": false,
+                    "required": [
+                        "type",
+                        "text"
+                    ]
+                },
+                {
+                    "type": "object",
+                    "properties": {
+                        "type": {
+                            "type": "string",
+                            "const": "image",
+                            "default": "image"
+                        },
+                        "data": {
+                            "type": "string",
+                            "contentEncoding": "base64"
+                        }
+                    },
+                    "additionalProperties": false,
+                    "required": [
+                        "type",
+                        "data"
+                    ]
+                },
+                {
+                    "$ref": "#/components/schemas/ToolCallDelta"
+                }
+            ]
+        },
         "TokenLogProbs": {
             "type": "object",
             "properties": {
@@ -4250,6 +4287,11 @@
         "ToolCallDelta": {
             "type": "object",
             "properties": {
+                "type": {
+                    "type": "string",
+                    "const": "tool_call",
+                    "default": "tool_call"
+                },
                 "content": {
                     "oneOf": [
                         {
@@ -4266,6 +4308,7 @@
             },
             "additionalProperties": false,
             "required": [
+                "type",
                 "content",
                 "parse_status"
             ]
@@ -4275,8 +4318,8 @@
             "enum": [
                 "started",
                 "in_progress",
-                "failure",
-                "success"
+                "failed",
+                "succeeded"
             ]
         },
         "CompletionRequest": {
@@ -4777,18 +4820,16 @@
                 "step_id": {
                     "type": "string"
                 },
-                "text_delta": {
-                    "type": "string"
-                },
-                "tool_call_delta": {
-                    "$ref": "#/components/schemas/ToolCallDelta"
+                "delta": {
+                    "$ref": "#/components/schemas/ContentDelta"
                 }
             },
             "additionalProperties": false,
             "required": [
                 "event_type",
                 "step_type",
-                "step_id"
+                "step_id",
+                "delta"
             ]
         },
         "AgentTurnResponseStepStartPayload": {
@@ -8758,6 +8799,10 @@
            "name": "CompletionResponseStreamChunk",
            "description": "streamed completion response.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/CompletionResponseStreamChunk\" />"
        },
+        {
+            "name": "ContentDelta",
+            "description": "<SchemaDefinition schemaRef=\"#/components/schemas/ContentDelta\" />"
+        },
        {
            "name": "CreateAgentRequest",
            "description": "<SchemaDefinition schemaRef=\"#/components/schemas/CreateAgentRequest\" />"
@@ -9392,6 +9437,7 @@
                "CompletionRequest",
                "CompletionResponse",
                "CompletionResponseStreamChunk",
+                "ContentDelta",
                "CreateAgentRequest",
                "CreateAgentSessionRequest",
                "CreateAgentTurnRequest",
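The new ContentDelta schema above replaces the old string-or-ToolCallDelta union in streamed events. As a quick illustration (not taken from the spec itself), payloads like the following would satisfy its three variants; the concrete values are invented for the example:

# Illustrative only: example payloads matching the three "ContentDelta" variants
# defined in the schema above. The values are made up for demonstration.
import base64
import json

text_delta = {"type": "text", "text": "Hello, wor"}

image_delta = {
    "type": "image",
    # "data" is a base64-encoded string, per "contentEncoding": "base64"
    "data": base64.b64encode(b"\x89PNG...").decode("ascii"),
}

tool_call_delta = {
    "type": "tool_call",
    # "content" can be a partial string while the call is still streaming,
    # or a fully parsed tool call once parsing has finished
    "content": '{"query": "latest llama',
    "parse_status": "in_progress",
}

for delta in (text_delta, image_delta, tool_call_delta):
    print(json.dumps(delta))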
@@ -150,6 +150,8 @@ components:
   AgentTurnResponseStepProgressPayload:
     additionalProperties: false
     properties:
+      delta:
+        $ref: '#/components/schemas/ContentDelta'
       event_type:
         const: step_progress
         default: step_progress
@@ -163,14 +165,11 @@ components:
         - shield_call
         - memory_retrieval
         type: string
-      text_delta:
-        type: string
-      tool_call_delta:
-        $ref: '#/components/schemas/ToolCallDelta'
     required:
     - event_type
     - step_type
     - step_id
+    - delta
     type: object
   AgentTurnResponseStepStartPayload:
     additionalProperties: false
@@ -462,9 +461,7 @@ components:
     additionalProperties: false
     properties:
       delta:
-        oneOf:
-        - type: string
-        - $ref: '#/components/schemas/ToolCallDelta'
+        $ref: '#/components/schemas/ContentDelta'
       event_type:
         $ref: '#/components/schemas/ChatCompletionResponseEventType'
       logprobs:
@@ -571,6 +568,34 @@ components:
     - delta
     title: streamed completion response.
     type: object
+  ContentDelta:
+    oneOf:
+    - additionalProperties: false
+      properties:
+        text:
+          type: string
+        type:
+          const: text
+          default: text
+          type: string
+      required:
+      - type
+      - text
+      type: object
+    - additionalProperties: false
+      properties:
+        data:
+          contentEncoding: base64
+          type: string
+        type:
+          const: image
+          default: image
+          type: string
+      required:
+      - type
+      - data
+      type: object
+    - $ref: '#/components/schemas/ToolCallDelta'
   CreateAgentRequest:
     additionalProperties: false
     properties:
@@ -2664,7 +2689,12 @@ components:
       - $ref: '#/components/schemas/ToolCall'
       parse_status:
         $ref: '#/components/schemas/ToolCallParseStatus'
+      type:
+        const: tool_call
+        default: tool_call
+        type: string
     required:
+    - type
     - content
     - parse_status
     type: object
@@ -2672,8 +2702,8 @@ components:
     enum:
     - started
     - in_progress
-    - failure
-    - success
+    - failed
+    - succeeded
     type: string
   ToolChoice:
     enum:
@@ -2888,8 +2918,8 @@ components:
       content:
         $ref: '#/components/schemas/InterleavedContent'
       role:
-        const: ipython
-        default: ipython
+        const: tool
+        default: tool
         type: string
       tool_name:
         oneOf:
@@ -5500,6 +5530,8 @@ tags:
     <SchemaDefinition schemaRef="#/components/schemas/CompletionResponseStreamChunk"
     />'
   name: CompletionResponseStreamChunk
+- description: <SchemaDefinition schemaRef="#/components/schemas/ContentDelta" />
+  name: ContentDelta
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateAgentRequest"
    />
  name: CreateAgentRequest
@@ -5939,6 +5971,7 @@ x-tagGroups:
  - CompletionRequest
  - CompletionResponse
  - CompletionResponseStreamChunk
+  - ContentDelta
  - CreateAgentRequest
  - CreateAgentSessionRequest
  - CreateAgentTurnRequest
@@ -22,12 +22,11 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
 from llama_stack.apis.inference import (
     CompletionMessage,
     SamplingParams,
     ToolCall,
-    ToolCallDelta,
     ToolChoice,
     ToolPromptFormat,
     ToolResponse,
@@ -216,8 +215,7 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
     step_type: StepType
     step_id: str
 
-    text_delta: Optional[str] = None
-    tool_call_delta: Optional[ToolCallDelta] = None
+    delta: ContentDelta
 
 
 @json_schema_type
@@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils
 from termcolor import cprint
 
 from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
+from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import ToolResponseMessage
 
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+
 
 class LogEvent:
     def __init__(
@@ -57,8 +61,11 @@ class EventLogger:
             # since it does not produce event but instead
             # a Message
             if isinstance(chunk, ToolResponseMessage):
-                yield chunk, LogEvent(
-                    role="CustomTool", content=chunk.content, color="grey"
+                yield (
+                    chunk,
+                    LogEvent(
+                        role="CustomTool", content=chunk.content, color="grey"
+                    ),
                 )
                 continue
 
@@ -80,14 +87,20 @@ class EventLogger:
             ):
                 violation = event.payload.step_details.violation
                 if not violation:
-                    yield event, LogEvent(
-                        role=step_type, content="No Violation", color="magenta"
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type, content="No Violation", color="magenta"
+                        ),
                     )
                 else:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"{violation.metadata} {violation.user_message}",
-                        color="red",
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type,
+                            content=f"{violation.metadata} {violation.user_message}",
+                            color="red",
+                        ),
                     )
 
             # handle inference
@@ -95,8 +108,11 @@ class EventLogger:
             if stream:
                 if event_type == EventType.step_start.value:
                     # TODO: Currently this event is never received
-                    yield event, LogEvent(
-                        role=step_type, content="", end="", color="yellow"
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type, content="", end="", color="yellow"
+                        ),
                     )
                 elif event_type == EventType.step_progress.value:
                     # HACK: if previous was not step/event was not inference's step_progress
@@ -107,24 +123,34 @@ class EventLogger:
                         previous_event_type != EventType.step_progress.value
                         and previous_step_type != StepType.inference
                     ):
-                        yield event, LogEvent(
-                            role=step_type, content="", end="", color="yellow"
+                        yield (
+                            event,
+                            LogEvent(
+                                role=step_type, content="", end="", color="yellow"
+                            ),
                        )
 
-                    if event.payload.tool_call_delta:
-                        if isinstance(event.payload.tool_call_delta.content, str):
-                            yield event, LogEvent(
-                                role=None,
-                                content=event.payload.tool_call_delta.content,
-                                end="",
-                                color="cyan",
+                    delta = event.payload.delta
+                    if delta.type == "tool_call":
+                        if delta.parse_status == ToolCallParseStatus.succeeded:
+                            yield (
+                                event,
+                                LogEvent(
+                                    role=None,
+                                    content=delta.content,
+                                    end="",
+                                    color="cyan",
+                                ),
                            )
                    else:
-                        yield event, LogEvent(
-                            role=None,
-                            content=event.payload.text_delta,
-                            end="",
-                            color="yellow",
+                        yield (
+                            event,
+                            LogEvent(
+                                role=None,
+                                content=delta.text,
+                                end="",
+                                color="yellow",
+                            ),
                        )
                else:
                    # step_complete
@@ -140,10 +166,13 @@ class EventLogger:
                    )
                else:
                    content = response.content
-                yield event, LogEvent(
-                    role=step_type,
-                    content=content,
-                    color="yellow",
+                yield (
+                    event,
+                    LogEvent(
+                        role=step_type,
+                        content=content,
+                        color="yellow",
+                    ),
                )
 
        # handle tool_execution
@@ -155,16 +184,22 @@ class EventLogger:
            ):
                details = event.payload.step_details
                for t in details.tool_calls:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"Tool:{t.tool_name} Args:{t.arguments}",
-                        color="green",
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type,
+                            content=f"Tool:{t.tool_name} Args:{t.arguments}",
+                            color="green",
+                        ),
                    )
                for r in details.tool_responses:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"Tool:{r.tool_name} Response:{r.content}",
-                        color="green",
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type,
+                            content=f"Tool:{r.tool_name} Response:{r.content}",
+                            color="green",
+                        ),
                    )
 
        if (
@@ -172,15 +207,16 @@ class EventLogger:
            and event_type == EventType.step_complete.value
        ):
            details = event.payload.step_details
-            inserted_context = interleaved_text_media_as_str(
-                details.inserted_context
-            )
+            inserted_context = interleaved_content_as_str(details.inserted_context)
            content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}"
 
-            yield event, LogEvent(
-                role=step_type,
-                content=content,
-                color="cyan",
+            yield (
+                event,
+                LogEvent(
+                    role=step_type,
+                    content=content,
+                    color="cyan",
+                ),
            )
 
        previous_event_type = event_type
@@ -5,10 +5,12 @@
 # the root directory of this source tree.
 
 import base64
+from enum import Enum
 from typing import Annotated, List, Literal, Optional, Union
 
-from llama_models.schema_utils import json_schema_type, register_schema
+from llama_models.llama3.api.datatypes import ToolCall
+
+from llama_models.schema_utils import json_schema_type, register_schema
 from pydantic import BaseModel, Field, field_serializer, model_validator
 
 
@@ -60,3 +62,42 @@ InterleavedContent = register_schema(
     Union[str, InterleavedContentItem, List[InterleavedContentItem]],
     name="InterleavedContent",
 )
+
+
+class TextDelta(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+class ImageDelta(BaseModel):
+    type: Literal["image"] = "image"
+    data: bytes
+
+
+@json_schema_type
+class ToolCallParseStatus(Enum):
+    started = "started"
+    in_progress = "in_progress"
+    failed = "failed"
+    succeeded = "succeeded"
+
+
+@json_schema_type
+class ToolCallDelta(BaseModel):
+    type: Literal["tool_call"] = "tool_call"
+
+    # you either send an in-progress tool call so the client can stream a long
+    # code generation or you send the final parsed tool call at the end of the
+    # stream
+    content: Union[str, ToolCall]
+    parse_status: ToolCallParseStatus
+
+
+# streaming completions send a stream of ContentDeltas
+ContentDelta = register_schema(
+    Annotated[
+        Union[TextDelta, ImageDelta, ToolCallDelta],
+        Field(discriminator="type"),
+    ],
+    name="ContentDelta",
+)
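A minimal sketch of how downstream code can use the models introduced above, assuming they are importable from llama_stack.apis.common.content_types as this change makes them; the describe helper is hypothetical:

# Sketch, not part of the change: constructing deltas and branching on the
# "type" discriminator instead of isinstance() checks.
from llama_stack.apis.common.content_types import (
    TextDelta,
    ToolCallDelta,
    ToolCallParseStatus,
)


def describe(delta) -> str:
    # The Literal "type" field doubles as the discriminator for the union.
    if delta.type == "text":
        return f"text chunk: {delta.text!r}"
    if delta.type == "tool_call":
        return f"tool call ({delta.parse_status.value}): {delta.content!r}"
    return f"unhandled delta type: {delta.type}"


print(describe(TextDelta(text="Hello")))
print(
    describe(
        ToolCallDelta(
            content='get_weather("Paris"',
            parse_status=ToolCallParseStatus.in_progress,
        )
    )
)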
@@ -29,7 +29,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated
 
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
 from llama_stack.apis.models import Model
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
@@ -147,26 +147,12 @@ class ChatCompletionResponseEventType(Enum):
     progress = "progress"
 
 
-@json_schema_type
-class ToolCallParseStatus(Enum):
-    started = "started"
-    in_progress = "in_progress"
-    failure = "failure"
-    success = "success"
-
-
-@json_schema_type
-class ToolCallDelta(BaseModel):
-    content: Union[str, ToolCall]
-    parse_status: ToolCallParseStatus
-
-
 @json_schema_type
 class ChatCompletionResponseEvent(BaseModel):
     """Chat completion response event."""
 
     event_type: ChatCompletionResponseEventType
-    delta: Union[str, ToolCallDelta]
+    delta: ContentDelta
     logprobs: Optional[List[TokenLogProbs]] = None
     stop_reason: Optional[StopReason] = None
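With delta now typed as ContentDelta, streaming consumers switch from isinstance checks on str/ToolCallDelta to the type discriminator. A hedged sketch of that migration, with collect_stream as a made-up helper over the stream chunks defined above:

# Sketch, not part of the change: accumulating streamed output under the new
# ContentDelta-based event shape. `chunks` is assumed to be an iterable of
# ChatCompletionResponseStreamChunk objects as redefined above.
from llama_stack.apis.common.content_types import ToolCallParseStatus


def collect_stream(chunks):
    text = ""
    tool_calls = []
    for chunk in chunks:
        delta = chunk.event.delta
        if delta.type == "text":
            # previously: isinstance(delta, str) and `text += delta`
            text += delta.text
        elif delta.type == "tool_call":
            # previously: isinstance(delta, ToolCallDelta) with parse_status "success"
            if delta.parse_status == ToolCallParseStatus.succeeded:
                tool_calls.append(delta.content)
    return text, tool_calls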
@@ -4,9 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import argparse
-
 import importlib.resources
-
 import os
 import shutil
 from functools import lru_cache
@@ -14,14 +12,12 @@ from pathlib import Path
 from typing import List, Optional
 
 from llama_stack.cli.subcommand import Subcommand
-
 from llama_stack.distribution.datatypes import (
     BuildConfig,
     DistributionSpec,
     Provider,
     StackRunConfig,
 )
-
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -296,6 +292,7 @@ class StackBuild(Subcommand):
             / f"templates/{template_name}/run.yaml"
         )
         with importlib.resources.as_file(template_path) as path:
+            run_config_file = build_dir / f"{build_config.name}-run.yaml"
             shutil.copy(path, run_config_file)
         # Find all ${env.VARIABLE} patterns
         cprint("Build Successful!", color="green")
@@ -9,12 +9,10 @@ import inspect
 import json
 import logging
 import os
-import queue
-import threading
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
+from typing import Any, get_args, get_origin, Optional, TypeVar
 
 import httpx
 import yaml
@@ -64,71 +62,6 @@ def in_notebook():
     return True
 
 
-def stream_across_asyncio_run_boundary(
-    async_gen_maker,
-    pool_executor: ThreadPoolExecutor,
-    path: Optional[str] = None,
-    provider_data: Optional[dict[str, Any]] = None,
-) -> Generator[T, None, None]:
-    result_queue = queue.Queue()
-    stop_event = threading.Event()
-
-    async def consumer():
-        # make sure we make the generator in the event loop context
-        gen = await async_gen_maker()
-        await start_trace(path, {"__location__": "library_client"})
-        if provider_data:
-            set_request_provider_data(
-                {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
-            )
-        try:
-            async for item in await gen:
-                result_queue.put(item)
-        except Exception as e:
-            print(f"Error in generator {e}")
-            result_queue.put(e)
-        except asyncio.CancelledError:
-            return
-        finally:
-            result_queue.put(StopIteration)
-            stop_event.set()
-            await end_trace()
-
-    def run_async():
-        # Run our own loop to avoid double async generator cleanup which is done
-        # by asyncio.run()
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        try:
-            task = loop.create_task(consumer())
-            loop.run_until_complete(task)
-        finally:
-            # Handle pending tasks like a generator's athrow()
-            pending = asyncio.all_tasks(loop)
-            if pending:
-                loop.run_until_complete(
-                    asyncio.gather(*pending, return_exceptions=True)
-                )
-            loop.close()
-
-    future = pool_executor.submit(run_async)
-
-    try:
-        # yield results as they come in
-        while not stop_event.is_set() or not result_queue.empty():
-            try:
-                item = result_queue.get(timeout=0.1)
-                if item is StopIteration:
-                    break
-                if isinstance(item, Exception):
-                    raise item
-                yield item
-            except queue.Empty:
-                continue
-    finally:
-        future.result()
-
-
 def convert_pydantic_to_json_value(value: Any) -> Any:
     if isinstance(value, Enum):
         return value.value
@@ -184,7 +117,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
-            config_path_or_template_name, custom_provider_registry
+            config_path_or_template_name, custom_provider_registry, provider_data
         )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.skip_logger_removal = skip_logger_removal
@@ -210,39 +143,30 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
                 root_logger.removeHandler(handler)
                 print(f"Removed handler {handler.__class__.__name__} from root logger")
 
-    def _get_path(
-        self,
-        cast_to: Any,
-        options: Any,
-        *,
-        stream=False,
-        stream_cls=None,
-    ):
-        return options.url
-
     def request(self, *args, **kwargs):
-        path = self._get_path(*args, **kwargs)
         if kwargs.get("stream"):
-            return stream_across_asyncio_run_boundary(
-                lambda: self.async_client.request(*args, **kwargs),
-                self.pool_executor,
-                path=path,
-                provider_data=self.provider_data,
-            )
-        else:
+            # NOTE: We are using AsyncLlamaStackClient under the hood
+            # A new event loop is needed to convert the AsyncStream
+            # from async client into SyncStream return type for streaming
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
 
-            async def _traced_request():
-                if self.provider_data:
-                    set_request_provider_data(
-                        {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
-                    )
-                await start_trace(path, {"__location__": "library_client"})
+            def sync_generator():
                 try:
-                    return await self.async_client.request(*args, **kwargs)
+                    async_stream = loop.run_until_complete(
+                        self.async_client.request(*args, **kwargs)
+                    )
+                    while True:
+                        chunk = loop.run_until_complete(async_stream.__anext__())
+                        yield chunk
+                except StopAsyncIteration:
+                    pass
                 finally:
-                    await end_trace()
+                    loop.close()
 
-            return asyncio.run(_traced_request())
+            return sync_generator()
+        else:
+            return asyncio.run(self.async_client.request(*args, **kwargs))
 
 
 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@@ -250,9 +174,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self,
         config_path_or_template_name: str,
         custom_provider_registry: Optional[ProviderRegistry] = None,
+        provider_data: Optional[dict[str, Any]] = None,
     ):
         super().__init__()
 
         # when using the library client, we should not log to console since many
         # of our logs are intended for server-side usage
         current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
@@ -273,6 +197,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.config_path_or_template_name = config_path_or_template_name
         self.config = config
         self.custom_provider_registry = custom_provider_registry
+        self.provider_data = provider_data
 
     async def initialize(self):
         try:
@@ -329,17 +254,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")
 
+        if self.provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
+            )
+        await start_trace(options.url, {"__location__": "library_client"})
         if stream:
-            return self._call_streaming(
+            response = await self._call_streaming(
                 cast_to=cast_to,
                 options=options,
                 stream_cls=stream_cls,
             )
         else:
-            return await self._call_non_streaming(
+            response = await self._call_non_streaming(
                 cast_to=cast_to,
                 options=options,
             )
+        await end_trace()
+        return response
 
     async def _call_non_streaming(
         self,
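The rewritten request() above bridges the async client into a synchronous stream by driving a private event loop. A standalone sketch of that pattern, using only the standard library and invented names:

# Standalone sketch of the async-to-sync bridge used by the new request():
# drive an async stream from a dedicated event loop and expose it as a plain
# generator. Names here are illustrative, not from the library.
import asyncio


async def fake_async_stream():
    # stand-in for the AsyncStream returned by the async client
    for i in range(3):
        await asyncio.sleep(0)
        yield f"chunk-{i}"


def as_sync_generator(agen):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    def gen():
        try:
            while True:
                # __anext__() raises StopAsyncIteration when the stream ends
                yield loop.run_until_complete(agen.__anext__())
        except StopAsyncIteration:
            pass
        finally:
            loop.close()

    return gen()


for chunk in as_sync_generator(fake_async_stream()):
    print(chunk)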
@@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
 
 
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v4"
+KEY_VERSION = "v5"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
 
 
@@ -40,7 +40,12 @@ from llama_stack.apis.agents import (
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.common.content_types import TextContentItem, URL
+from llama_stack.apis.common.content_types import (
+    TextContentItem,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    URL,
+)
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,
@@ -49,8 +54,6 @@ from llama_stack.apis.inference import (
     SamplingParams,
     StopReason,
     SystemMessage,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolDefinition,
     ToolResponse,
     ToolResponseMessage,
@@ -411,8 +414,8 @@ class ChatAgent(ShieldRunnerMixin):
                 payload=AgentTurnResponseStepProgressPayload(
                     step_type=StepType.tool_execution.value,
                     step_id=step_id,
-                    tool_call_delta=ToolCallDelta(
-                        parse_status=ToolCallParseStatus.success,
+                    delta=ToolCallDelta(
+                        parse_status=ToolCallParseStatus.succeeded,
                         content=ToolCall(
                             call_id="",
                             tool_name=MEMORY_QUERY_TOOL,
@@ -507,8 +510,8 @@ class ChatAgent(ShieldRunnerMixin):
                 continue
 
             delta = event.delta
-            if isinstance(delta, ToolCallDelta):
-                if delta.parse_status == ToolCallParseStatus.success:
+            if delta.type == "tool_call":
+                if delta.parse_status == ToolCallParseStatus.succeeded:
                     tool_calls.append(delta.content)
                 if stream:
                     yield AgentTurnResponseStreamChunk(
@@ -516,21 +519,20 @@ class ChatAgent(ShieldRunnerMixin):
                         payload=AgentTurnResponseStepProgressPayload(
                             step_type=StepType.inference.value,
                             step_id=step_id,
-                            text_delta="",
-                            tool_call_delta=delta,
+                            delta=delta,
                         )
                     )
                 )
 
-            elif isinstance(delta, str):
-                content += delta
+            elif delta.type == "text":
+                content += delta.text
                 if stream and event.stop_reason is None:
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
                             payload=AgentTurnResponseStepProgressPayload(
                                 step_type=StepType.inference.value,
                                 step_id=step_id,
-                                text_delta=event.delta,
+                                delta=delta,
                             )
                         )
                     )
@@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.sku_list import resolve_model
 
+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -32,8 +37,6 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     TokenLogProbs,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolChoice,
 )
 from llama_stack.apis.models import Model, ModelType
@@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl(
             ]
 
         yield CompletionResponseStreamChunk(
-            delta=text,
+            delta=TextDelta(text=text),
             stop_reason=stop_reason,
             logprobs=logprobs if request.logprobs else None,
         )
 
         if stop_reason is None:
             yield CompletionResponseStreamChunk(
-                delta="",
+                delta=TextDelta(text=""),
                 stop_reason=StopReason.out_of_tokens,
             )
 
@@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.start,
-                delta="",
+                delta=TextDelta(text=""),
             )
         )
 
@@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl(
                     parse_status=ToolCallParseStatus.in_progress,
                 )
             else:
-                delta = text
+                delta = TextDelta(text=text)
 
             if stop_reason is None:
                 if request.logprobs:
@@ -428,7 +431,7 @@ class MetaReferenceInferenceImpl(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
                         content="",
-                        parse_status=ToolCallParseStatus.failure,
+                        parse_status=ToolCallParseStatus.failed,
                     ),
                     stop_reason=stop_reason,
                 )
@@ -440,7 +443,7 @@ class MetaReferenceInferenceImpl(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
                         content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
+                        parse_status=ToolCallParseStatus.succeeded,
                     ),
                     stop_reason=stop_reason,
                 )
@@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.complete,
-                delta="",
+                delta=TextDelta(text=""),
                 stop_reason=stop_reason,
             )
         )
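The provider-side pattern after this change: raw text is wrapped in TextDelta and tool-call results carry the renamed parse statuses. A rough sketch with a made-up emit_deltas helper, not the actual provider code:

# Sketch only: how a provider-style generator emits deltas under the new types.
from llama_stack.apis.common.content_types import (
    TextDelta,
    ToolCallDelta,
    ToolCallParseStatus,
)


def emit_deltas(tokens, parsed_tool_call=None, ended_mid_tool_call=False):
    # Stream each token as a TextDelta instead of a bare string.
    for token in tokens:
        yield TextDelta(text=token)
    if parsed_tool_call is not None:
        # a fully parsed call is emitted once at the end of the stream
        yield ToolCallDelta(
            content=parsed_tool_call,
            parse_status=ToolCallParseStatus.succeeded,
        )
    elif ended_mid_tool_call:
        # generation stopped before the call could be parsed
        yield ToolCallDelta(content="", parse_status=ToolCallParseStatus.failed)


for delta in emit_deltas(["Hello", ", ", "world"]):
    print(delta)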
@@ -30,13 +30,10 @@ from llama_stack.apis.telemetry import (
     Trace,
     UnstructuredLogEvent,
 )
-
 from llama_stack.distribution.datatypes import Api
-
 from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
     ConsoleSpanProcessor,
 )
-
 from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
     SQLiteSpanProcessor,
 )
@@ -52,6 +49,7 @@ _GLOBAL_STORAGE = {
     "up_down_counters": {},
 }
 _global_lock = threading.Lock()
+_TRACER_PROVIDER = None
 
 
 def string_to_trace_id(s: str) -> int:
@@ -80,31 +78,34 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             }
         )
 
-        provider = TracerProvider(resource=resource)
-        trace.set_tracer_provider(provider)
-        if TelemetrySink.OTEL in self.config.sinks:
-            otlp_exporter = OTLPSpanExporter(
-                endpoint=self.config.otel_endpoint,
-            )
-            span_processor = BatchSpanProcessor(otlp_exporter)
-            trace.get_tracer_provider().add_span_processor(span_processor)
-            metric_reader = PeriodicExportingMetricReader(
-                OTLPMetricExporter(
-                    endpoint=self.config.otel_endpoint,
-                )
-            )
-            metric_provider = MeterProvider(
-                resource=resource, metric_readers=[metric_reader]
-            )
-            metrics.set_meter_provider(metric_provider)
-            self.meter = metrics.get_meter(__name__)
-        if TelemetrySink.SQLITE in self.config.sinks:
-            trace.get_tracer_provider().add_span_processor(
-                SQLiteSpanProcessor(self.config.sqlite_db_path)
-            )
-            self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
-        if TelemetrySink.CONSOLE in self.config.sinks:
-            trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
+        global _TRACER_PROVIDER
+        if _TRACER_PROVIDER is None:
+            provider = TracerProvider(resource=resource)
+            trace.set_tracer_provider(provider)
+            _TRACER_PROVIDER = provider
+            if TelemetrySink.OTEL in self.config.sinks:
+                otlp_exporter = OTLPSpanExporter(
+                    endpoint=self.config.otel_endpoint,
+                )
+                span_processor = BatchSpanProcessor(otlp_exporter)
+                trace.get_tracer_provider().add_span_processor(span_processor)
+                metric_reader = PeriodicExportingMetricReader(
+                    OTLPMetricExporter(
+                        endpoint=self.config.otel_endpoint,
+                    )
+                )
+                metric_provider = MeterProvider(
+                    resource=resource, metric_readers=[metric_reader]
+                )
+                metrics.set_meter_provider(metric_provider)
+                self.meter = metrics.get_meter(__name__)
+            if TelemetrySink.SQLITE in self.config.sinks:
+                trace.get_tracer_provider().add_span_processor(
+                    SQLiteSpanProcessor(self.config.sqlite_db_path)
+                )
+                self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
+            if TelemetrySink.CONSOLE in self.config.sinks:
+                trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
         self._lock = _global_lock
 
     async def initialize(self) -> None:
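The telemetry change above guards provider setup behind a module-level _TRACER_PROVIDER so repeated adapter construction configures the process-wide tracer only once. A generic sketch of that init-once pattern with placeholder names:

# Generic sketch of the init-once pattern: the first construction wins, later
# adapters reuse the same provider so exporters are not registered twice.
_PROVIDER = None


class FakeProvider:
    def __init__(self, resource):
        self.resource = resource


def get_or_create_provider(resource):
    global _PROVIDER
    if _PROVIDER is None:
        _PROVIDER = FakeProvider(resource)
    return _PROVIDER


p1 = get_or_create_provider({"service.name": "llama-stack"})
p2 = get_or_create_provider({"service.name": "ignored-second-call"})
assert p1 is p2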
@@ -30,6 +30,11 @@ from groq.types.shared.function_definition import FunctionDefinition
 from llama_models.llama3.api.datatypes import ToolParamDefinition
 
+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -40,8 +45,6 @@ from llama_stack.apis.inference import (
     Message,
     StopReason,
     ToolCall,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -162,7 +165,7 @@ def convert_chat_completion_response(
 
 
 def _map_finish_reason_to_stop_reason(
-    finish_reason: Literal["stop", "length", "tool_calls"]
+    finish_reason: Literal["stop", "length", "tool_calls"],
 ) -> StopReason:
     """
     Convert a Groq chat completion finish_reason to a StopReason.
@@ -185,7 +188,6 @@ def _map_finish_reason_to_stop_reason(
 async def convert_chat_completion_response_stream(
     stream: Stream[ChatCompletionChunk],
 ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-
     event_type = ChatCompletionResponseEventType.start
     for chunk in stream:
         choice = chunk.choices[0]
@@ -194,7 +196,7 @@ async def convert_chat_completion_response_stream(
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.complete,
-                    delta=choice.delta.content or "",
+                    delta=TextDelta(text=choice.delta.content or ""),
                     logprobs=None,
                     stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
                 )
@@ -213,7 +215,7 @@
                     event_type=event_type,
                     delta=ToolCallDelta(
                         content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
+                        parse_status=ToolCallParseStatus.succeeded,
                     ),
                 )
             )
@@ -221,7 +223,7 @@
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=event_type,
-                    delta=choice.delta.content or "",
+                    delta=TextDelta(text=choice.delta.content or ""),
                     logprobs=None,
                 )
             )
@ -34,6 +34,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||||
from openai.types.completion import Completion as OpenAICompletion
|
from openai.types.completion import Completion as OpenAICompletion
|
||||||
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
|
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import (
|
||||||
|
TextDelta,
|
||||||
|
ToolCallDelta,
|
||||||
|
ToolCallParseStatus,
|
||||||
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
|
@ -48,8 +53,6 @@ from llama_stack.apis.inference import (
|
||||||
Message,
|
Message,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
TokenLogProbs,
|
TokenLogProbs,
|
||||||
ToolCallDelta,
|
|
||||||
ToolCallParseStatus,
|
|
||||||
ToolResponseMessage,
|
ToolResponseMessage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
|
@ -432,69 +435,6 @@ async def convert_openai_chat_completion_stream(
|
||||||
"""
|
"""
|
||||||
Convert a stream of OpenAI chat completion chunks into a stream
|
Convert a stream of OpenAI chat completion chunks into a stream
|
||||||
of ChatCompletionResponseStreamChunk.
|
of ChatCompletionResponseStreamChunk.
|
||||||
|
|
||||||
OpenAI ChatCompletionChunk:
|
|
||||||
choices: List[Choice]
|
|
||||||
|
|
||||||
OpenAI Choice: # different from the non-streamed Choice
|
|
||||||
delta: ChoiceDelta
|
|
||||||
finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
|
|
||||||
logprobs: Optional[ChoiceLogprobs]
|
|
||||||
|
|
||||||
OpenAI ChoiceDelta:
|
|
||||||
content: Optional[str]
|
|
||||||
role: Optional[Literal["system", "user", "assistant", "tool"]]
|
|
||||||
tool_calls: Optional[List[ChoiceDeltaToolCall]]
|
|
||||||
|
|
||||||
OpenAI ChoiceDeltaToolCall:
|
|
||||||
index: int
|
|
||||||
id: Optional[str]
|
|
||||||
function: Optional[ChoiceDeltaToolCallFunction]
|
|
||||||
type: Optional[Literal["function"]]
|
|
||||||
|
|
||||||
OpenAI ChoiceDeltaToolCallFunction:
|
|
||||||
name: Optional[str]
|
|
||||||
arguments: Optional[str]
|
|
||||||
|
|
||||||
->
|
|
||||||
|
|
||||||
ChatCompletionResponseStreamChunk:
|
|
||||||
event: ChatCompletionResponseEvent
|
|
||||||
|
|
||||||
ChatCompletionResponseEvent:
|
|
||||||
event_type: ChatCompletionResponseEventType
|
|
||||||
delta: Union[str, ToolCallDelta]
|
|
||||||
logprobs: Optional[List[TokenLogProbs]]
|
|
||||||
stop_reason: Optional[StopReason]
|
|
||||||
|
|
||||||
ChatCompletionResponseEventType:
|
|
||||||
start = "start"
|
|
||||||
progress = "progress"
|
|
||||||
complete = "complete"
|
|
||||||
|
|
||||||
ToolCallDelta:
|
|
||||||
content: Union[str, ToolCall]
|
|
||||||
parse_status: ToolCallParseStatus
|
|
||||||
|
|
||||||
ToolCall:
|
|
||||||
call_id: str
|
|
||||||
tool_name: str
|
|
||||||
arguments: str
|
|
||||||
|
|
||||||
ToolCallParseStatus:
|
|
||||||
started = "started"
|
|
||||||
in_progress = "in_progress"
|
|
||||||
failure = "failure"
|
|
||||||
success = "success"
|
|
||||||
|
|
||||||
TokenLogProbs:
|
|
||||||
logprobs_by_token: Dict[str, float]
|
|
||||||
- token, logprob
|
|
||||||
|
|
||||||
StopReason:
|
|
||||||
end_of_turn = "end_of_turn"
|
|
||||||
end_of_message = "end_of_message"
|
|
||||||
out_of_tokens = "out_of_tokens"
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
|
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
|
||||||
|
@ -543,7 +483,7 @@ async def convert_openai_chat_completion_stream(
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
event_type=next(event_type),
|
event_type=next(event_type),
|
||||||
delta=choice.delta.content,
|
delta=TextDelta(text=choice.delta.content),
|
||||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -561,7 +501,7 @@ async def convert_openai_chat_completion_stream(
|
||||||
event_type=next(event_type),
|
event_type=next(event_type),
|
||||||
delta=ToolCallDelta(
|
delta=ToolCallDelta(
|
||||||
content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
|
content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
|
||||||
parse_status=ToolCallParseStatus.success,
|
parse_status=ToolCallParseStatus.succeeded,
|
||||||
),
|
),
|
||||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||||
)
|
)
|
||||||
|
@ -570,7 +510,7 @@ async def convert_openai_chat_completion_stream(
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
event_type=next(event_type),
|
event_type=next(event_type),
|
||||||
delta=choice.delta.content or "", # content is not optional
|
delta=TextDelta(text=choice.delta.content or ""),
|
||||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -578,7 +518,7 @@ async def convert_openai_chat_completion_stream(
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
event_type=ChatCompletionResponseEventType.complete,
|
event_type=ChatCompletionResponseEventType.complete,
|
||||||
delta="",
|
delta=TextDelta(text=""),
|
||||||
stop_reason=stop_reason,
|
stop_reason=stop_reason,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -653,18 +593,6 @@ def _convert_openai_completion_logprobs(
|
||||||
) -> Optional[List[TokenLogProbs]]:
|
) -> Optional[List[TokenLogProbs]]:
|
||||||
"""
|
"""
|
||||||
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
|
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
|
||||||
|
|
||||||
OpenAI CompletionLogprobs:
|
|
||||||
text_offset: Optional[List[int]]
|
|
||||||
token_logprobs: Optional[List[float]]
|
|
||||||
tokens: Optional[List[str]]
|
|
||||||
top_logprobs: Optional[List[Dict[str, float]]]
|
|
||||||
|
|
||||||
->
|
|
||||||
|
|
||||||
TokenLogProbs:
|
|
||||||
logprobs_by_token: Dict[str, float]
|
|
||||||
- token, logprob
|
|
||||||
"""
|
"""
|
||||||
if not logprobs:
|
if not logprobs:
|
||||||
return None
|
return None
|
||||||
|
@ -679,28 +607,6 @@ def convert_openai_completion_choice(
|
||||||
) -> CompletionResponse:
|
) -> CompletionResponse:
|
||||||
"""
|
"""
|
||||||
Convert an OpenAI Completion Choice into a CompletionResponse.
|
Convert an OpenAI Completion Choice into a CompletionResponse.
|
||||||
|
|
||||||
OpenAI Completion Choice:
|
|
||||||
text: str
|
|
||||||
finish_reason: str
|
|
||||||
logprobs: Optional[ChoiceLogprobs]
|
|
||||||
|
|
||||||
->
|
|
||||||
|
|
||||||
CompletionResponse:
|
|
||||||
completion_message: CompletionMessage
|
|
||||||
logprobs: Optional[List[TokenLogProbs]]
|
|
||||||
|
|
||||||
CompletionMessage:
|
|
||||||
role: Literal["assistant"]
|
|
||||||
content: str | ImageMedia | List[str | ImageMedia]
|
|
||||||
stop_reason: StopReason
|
|
||||||
tool_calls: List[ToolCall]
|
|
||||||
|
|
||||||
class StopReason(Enum):
|
|
||||||
end_of_turn = "end_of_turn"
|
|
||||||
end_of_message = "end_of_message"
|
|
||||||
out_of_tokens = "out_of_tokens"
|
|
||||||
"""
|
"""
|
||||||
return CompletionResponse(
|
return CompletionResponse(
|
||||||
content=choice.text,
|
content=choice.text,
|
||||||
@@ -715,32 +621,11 @@ async def convert_openai_completion_stream(
 """
 Convert a stream of OpenAI Completions into a stream
 of ChatCompletionResponseStreamChunks.
-
-OpenAI Completion:
-id: str
-choices: List[OpenAICompletionChoice]
-created: int
-model: str
-system_fingerprint: Optional[str]
-usage: Optional[OpenAICompletionUsage]
-
-OpenAI CompletionChoice:
-finish_reason: str
-index: int
-logprobs: Optional[OpenAILogprobs]
-text: str
-
-->
-
-CompletionResponseStreamChunk:
-delta: str
-stop_reason: Optional[StopReason]
-logprobs: Optional[List[TokenLogProbs]]
 """
 async for chunk in stream:
 choice = chunk.choices[0]
 yield CompletionResponseStreamChunk(
-delta=choice.text,
+delta=TextDelta(text=choice.text),
 stop_reason=_convert_openai_finish_reason(choice.finish_reason),
 logprobs=_convert_openai_completion_logprobs(choice.logprobs),
 )
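Plain (non-chat) completion chunks get the same treatment: `delta` is now a `TextDelta` rather than a raw string. A small sketch of the consumer-side adjustment, assuming `stream` yields `CompletionResponseStreamChunk` objects as above (not code from the commit):

```python
async def collect_completion_text(stream) -> str:
    # Sketch only: the streamed text now lives on delta.text, not on delta itself.
    pieces = []
    async for chunk in stream:
        if chunk.delta.type == "text":
            pieces.append(chunk.delta.text)
    return "".join(pieces)
```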
@@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import (
 
 from pydantic import BaseModel, ValidationError
 
+from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import (
 ChatCompletionResponse,
 ChatCompletionResponseEventType,
@@ -27,8 +28,6 @@ from llama_stack.apis.inference import (
 JsonSchemaResponseFormat,
 LogProbConfig,
 SystemMessage,
-ToolCallDelta,
-ToolCallParseStatus,
 ToolChoice,
 UserMessage,
 )
@@ -196,7 +195,9 @@ class TestInference:
 1 <= len(chunks) <= 6
 ) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
 for chunk in chunks:
-if chunk.delta: # if there's a token, we expect logprobs
+if (
+chunk.delta.type == "text" and chunk.delta.text
+): # if there's a token, we expect logprobs
 assert chunk.logprobs, "Logprobs should not be empty"
 assert all(
 len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
@@ -463,7 +464,7 @@ class TestInference:
 
 if "Llama3.1" in inference_model:
 assert all(
-isinstance(chunk.event.delta, ToolCallDelta)
+chunk.event.delta.type == "tool_call"
 for chunk in grouped[ChatCompletionResponseEventType.progress]
 )
 first = grouped[ChatCompletionResponseEventType.progress][0]
@@ -474,8 +475,8 @@ class TestInference:
 
 last = grouped[ChatCompletionResponseEventType.progress][-1]
 # assert last.event.stop_reason == expected_stop_reason
-assert last.event.delta.parse_status == ToolCallParseStatus.success
-assert isinstance(last.event.delta.content, ToolCall)
+assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
+assert last.event.delta.content.type == "tool_call"
 
 call = last.event.delta.content
 assert call.tool_name == "get_weather"
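For tests being migrated, the pattern above generalizes: the fully parsed tool call shows up on the last `progress` chunk. A hedged helper sketch, assuming `chunks` is the list of streamed `ChatCompletionResponseStreamChunk` objects collected by a test (not part of the commit):

```python
from llama_stack.apis.common.content_types import ToolCallParseStatus


def last_tool_call(chunks):
    # Sketch: walk the stream backwards and return the first fully parsed
    # tool call, i.e. a tool_call delta whose parse_status is `succeeded`.
    for chunk in reversed(chunks):
        delta = chunk.event.delta
        if delta.type == "tool_call" and delta.parse_status == ToolCallParseStatus.succeeded:
            return delta.content
    return None
```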
@@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import SamplingParams, StopReason
 from pydantic import BaseModel
 
-from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+from llama_stack.apis.common.content_types import (
+ImageContentItem,
+TextContentItem,
+TextDelta,
+ToolCallDelta,
+ToolCallParseStatus,
+)
 
 from llama_stack.apis.inference import (
 ChatCompletionResponse,
@@ -22,8 +28,6 @@ from llama_stack.apis.inference import (
 CompletionResponse,
 CompletionResponseStreamChunk,
 Message,
-ToolCallDelta,
-ToolCallParseStatus,
 )
 
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -160,7 +164,7 @@ async def process_chat_completion_stream_response(
 yield ChatCompletionResponseStreamChunk(
 event=ChatCompletionResponseEvent(
 event_type=ChatCompletionResponseEventType.start,
-delta="",
+delta=TextDelta(text=""),
 )
 )
 
@@ -227,7 +231,7 @@ async def process_chat_completion_stream_response(
 yield ChatCompletionResponseStreamChunk(
 event=ChatCompletionResponseEvent(
 event_type=ChatCompletionResponseEventType.progress,
-delta=text,
+delta=TextDelta(text=text),
 stop_reason=stop_reason,
 )
 )
@@ -241,7 +245,7 @@ async def process_chat_completion_stream_response(
 event_type=ChatCompletionResponseEventType.progress,
 delta=ToolCallDelta(
 content="",
-parse_status=ToolCallParseStatus.failure,
+parse_status=ToolCallParseStatus.failed,
 ),
 stop_reason=stop_reason,
 )
@@ -253,7 +257,7 @@ async def process_chat_completion_stream_response(
 event_type=ChatCompletionResponseEventType.progress,
 delta=ToolCallDelta(
 content=tool_call,
-parse_status=ToolCallParseStatus.success,
+parse_status=ToolCallParseStatus.succeeded,
 ),
 stop_reason=stop_reason,
 )
@@ -262,7 +266,7 @@ async def process_chat_completion_stream_response(
 yield ChatCompletionResponseStreamChunk(
 event=ChatCompletionResponseEvent(
 event_type=ChatCompletionResponseEventType.complete,
-delta="",
+delta=TextDelta(text=""),
 stop_reason=stop_reason,
 )
 )
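Taken together, the stream processor now emits the same start → progress → complete sequence as before, but every event carries a typed delta. A condensed, hedged illustration of that sequence (not the full implementation; the import location of `ChatCompletionResponseEvent` and friends is assumed from the other imports in this diff):

```python
from llama_models.llama3.api.datatypes import StopReason
from llama_stack.apis.common.content_types import TextDelta, ToolCallDelta, ToolCallParseStatus
from llama_stack.apis.inference import (  # assumed home of the response types
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
)


async def sketch_stream(text_fragments, tool_call):
    # start: an empty text delta opens the stream
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.start,
            delta=TextDelta(text=""),
        )
    )
    # progress: one typed text delta per fragment
    for text in text_fragments:
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=TextDelta(text=text),
            )
        )
    # progress: a parsed tool call, with parse_status marking the outcome
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.progress,
            delta=ToolCallDelta(content=tool_call, parse_status=ToolCallParseStatus.succeeded),
        )
    )
    # complete: empty text delta plus the stop reason
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta=TextDelta(text=""),
            stop_reason=StopReason.end_of_turn,
        )
    )
```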
@@ -265,6 +265,7 @@ def chat_completion_request_to_messages(
 For eg. for llama_3_1, add system message with the appropriate tools or
 add user messsage for custom tools, etc.
 """
+assert llama_model is not None, "llama_model is required"
 model = resolve_model(llama_model)
 if model is None:
 log.error(f"Could not resolve model {llama_model}")
@@ -127,7 +127,8 @@ class TraceContext:
 def setup_logger(api: Telemetry, level: int = logging.INFO):
 global BACKGROUND_LOGGER
 
-BACKGROUND_LOGGER = BackgroundLogger(api)
+if BACKGROUND_LOGGER is None:
+BACKGROUND_LOGGER = BackgroundLogger(api)
 logger = logging.getLogger()
 logger.setLevel(level)
 logger.addHandler(TelemetryHandler())
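The guard makes repeated `setup_logger` calls reuse the first `BackgroundLogger` instead of replacing it, presumably to tolerate the stack being set up more than once in a single process (see the `skip_logger_removal=True` change in the test fixtures below). A hedged illustration of the resulting behaviour, assuming a `telemetry_api` object is in scope:

```python
setup_logger(telemetry_api)  # first call creates BACKGROUND_LOGGER
setup_logger(telemetry_api)  # later calls keep the existing instance;
                             # only the handler and level setup runs again
```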
@@ -12,6 +12,11 @@ from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack_client import LlamaStackClient
 
 
+def pytest_configure(config):
+config.option.tbstyle = "short"
+config.option.disable_warnings = True
+
+
 @pytest.fixture(scope="session")
 def provider_data():
 # check env for tavily secret, brave secret and inject all into provider data
@@ -29,6 +34,7 @@ def llama_stack_client(provider_data):
 client = LlamaStackAsLibraryClient(
 get_env_or_fail("LLAMA_STACK_CONFIG"),
 provider_data=provider_data,
+skip_logger_removal=True,
 )
 client.initialize()
 elif os.environ.get("LLAMA_STACK_BASE_URL"):
@@ -6,9 +6,9 @@
 
 import pytest
 
-from llama_stack_client.lib.inference.event_logger import EventLogger
 from pydantic import BaseModel
 
+
 PROVIDER_TOOL_PROMPT_FORMAT = {
 "remote::ollama": "python_list",
 "remote::together": "json",
@@ -39,7 +39,7 @@ def text_model_id(llama_stack_client):
 available_models = [
 model.identifier
 for model in llama_stack_client.models.list()
-if model.identifier.startswith("meta-llama")
+if model.identifier.startswith("meta-llama") and "405" not in model.identifier
 ]
 assert len(available_models) > 0
 return available_models[0]
@@ -208,12 +208,9 @@ def test_text_chat_completion_streaming(
 stream=True,
 )
 streamed_content = [
-str(log.content.lower().strip())
-for log in EventLogger().log(response)
-if log is not None
+str(chunk.event.delta.text.lower().strip()) for chunk in response
 ]
 assert len(streamed_content) > 0
-assert "assistant>" in streamed_content[0]
 assert expected.lower() in "".join(streamed_content)
 
 
@@ -250,17 +247,16 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
 def extract_tool_invocation_content(response):
 text_content: str = ""
 tool_invocation_content: str = ""
-for log in EventLogger().log(response):
-if log is None:
-continue
-if isinstance(log.content, str):
-text_content += log.content
-elif isinstance(log.content, object):
-if isinstance(log.content.content, str):
-continue
-elif isinstance(log.content.content, object):
-tool_invocation_content += f"[{log.content.content.tool_name}, {log.content.content.arguments}]"
+for chunk in response:
+delta = chunk.event.delta
+if delta.type == "text":
+text_content += delta.text
+elif delta.type == "tool_call":
+if isinstance(delta.content, str):
+tool_invocation_content += delta.content
+else:
+call = delta.content
+tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"
 
 return text_content, tool_invocation_content
 
 
@@ -280,7 +276,6 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
 )
 text_content, tool_invocation_content = extract_tool_invocation_content(response)
 
-assert "Assistant>" in text_content
 assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
 
 
@@ -368,10 +363,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
 stream=True,
 )
 streamed_content = [
-str(log.content.lower().strip())
-for log in EventLogger().log(response)
-if log is not None
+str(chunk.event.delta.text.lower().strip()) for chunk in response
 ]
 assert len(streamed_content) > 0
-assert "assistant>" in streamed_content[0]
 assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})