diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py
index 9cbd1fbd2..b5ad6ae91 100644
--- a/llama_stack/apis/agents/event_logger.py
+++ b/llama_stack/apis/agents/event_logger.py
@@ -9,10 +9,10 @@ from typing import Optional
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tool_utils import ToolUtils
 
-from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
-
 from termcolor import cprint
 
+from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
+
 
 class LogEvent:
     def __init__(
@@ -77,15 +77,15 @@ class EventLogger:
                 step_type == StepType.shield_call
                 and event_type == EventType.step_complete.value
             ):
-                response = event.payload.step_details.response
-                if not response.is_violation:
+                violation = event.payload.step_details.violation
+                if not violation:
                     yield event, LogEvent(
                         role=step_type, content="No Violation", color="magenta"
                     )
                 else:
                     yield event, LogEvent(
                         role=step_type,
-                        content=f"{response.violation_type} {response.violation_return_message}",
+                        content=f"{violation.metadata} {violation.user_message}",
                         color="red",
                     )
 
diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 4d67fb4f6..8c75c6893 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -10,21 +10,14 @@ from typing import Any, AsyncGenerator
 
 import fire
 import httpx
-
-from llama_stack.distribution.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 
 from termcolor import cprint
 
+from llama_stack.distribution.datatypes import RemoteProviderConfig
+
 from .event_logger import EventLogger
-from .inference import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    UserMessage,
-)
+from llama_stack.apis.inference import *  # noqa: F403
 
 
 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
@@ -48,7 +41,27 @@ class InferenceClient(Inference):
     async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
         async with httpx.AsyncClient() as client:
             async with client.stream(
                 "POST",
diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
index e8ed68fa6..eab86b76b 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
@@ -248,7 +248,7 @@ class ChatAgent(ShieldRunnerMixin):
         self,
         turn_id: str,
         messages: List[Message],
-        shields: List[ShieldDefinition],
+        shields: List[str],
         touchpoint: str,
     ) -> AsyncGenerator:
         if len(shields) == 0:
@@ -608,7 +608,6 @@ class ChatAgent(ShieldRunnerMixin):
             else:
                 return True
 
-        print(f"{enabled_tools=}")
         return AgentTool.memory.value in enabled_tools
 
     def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
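
For reference, a minimal usage sketch of the new flattened `InferenceClient.chat_completion` signature introduced above. This is not part of the diff: the host, port, and model identifier are placeholders, and it assumes `InferenceClient` is constructed with the server URL (as `get_client_impl` does from `RemoteProviderConfig`).

```python
# Hypothetical usage sketch -- not part of this change. The host, port, and
# model identifier are placeholders; InferenceClient is assumed to take the
# server URL as its constructor argument.
import asyncio

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.inference.client import InferenceClient


async def run(host: str, port: int) -> None:
    client = InferenceClient(f"http://{host}:{port}")

    # chat_completion is an async generator; with stream=True each yielded
    # item is a streamed response chunk rather than one completed response.
    async for chunk in client.chat_completion(
        model="Meta-Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Write a two-sentence poem about the moon.")],
        stream=True,
    ):
        print(chunk)


if __name__ == "__main__":
    asyncio.run(run("localhost", 5000))
```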