move straggler files and fix some important existing bugs

This commit is contained in:
Ashwin Bharambe 2024-08-05 09:24:45 -07:00
parent 5e972ece13
commit 7890921e5c
5 changed files with 303 additions and 11 deletions

View file

@ -4,15 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import copy
import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Optional
from llama_toolchain.inference.api import Inference
from llama_toolchain.safety.api import Safety
from .api.endpoints import * # noqa
import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Optional
from llama_toolchain.inference.api import ChatCompletionRequest
from llama_toolchain.inference.api.datatypes import (
@ -219,13 +221,14 @@ class AgentInstance(ShieldRunnerMixin):
)
session.turns.append(turn)
yield AgenticSystemTurnResponseStreamChunk(
chunk = AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
payload=AgenticSystemTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def run_shields_wrapper(
self,
@ -388,7 +391,10 @@ class AgentInstance(ShieldRunnerMixin):
stop_reason = None
async for chunk in self.inference_api.chat_completion(req):
event = chunk.event
if event.event_type != ChatCompletionResponseEventType.progress:
if event.event_type == ChatCompletionResponseEventType.start:
continue
elif event.event_type == ChatCompletionResponseEventType.complete:
stop_reason = StopReason.end_of_turn
continue
delta = event.delta
@ -439,7 +445,12 @@ class AgentInstance(ShieldRunnerMixin):
step_type=StepType.inference.value,
step_id=step_id,
step_details=InferenceStep(
step_id=step_id, turn_id=turn_id, model_response=message
# somewhere deep, we are re-assigning message or closing over some
# variable which causes message to mutate later on. fix with a
# `deepcopy` for now, but this is symptomatic of a deeper issue.
step_id=step_id,
turn_id=turn_id,
model_response=copy.deepcopy(message),
),
)
)