Further bug fixes

This commit is contained in:
Ashwin Bharambe 2024-09-20 15:15:57 -07:00
parent 9e16b0948b
commit e5a7001874
3 changed files with 30 additions and 18 deletions

View file

@@ -9,10 +9,10 @@ from typing import Optional
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tool_utils import ToolUtils from llama_models.llama3.api.tool_utils import ToolUtils
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from termcolor import cprint from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
class LogEvent: class LogEvent:
def __init__( def __init__(
@@ -77,15 +77,15 @@ class EventLogger:
step_type == StepType.shield_call step_type == StepType.shield_call
and event_type == EventType.step_complete.value and event_type == EventType.step_complete.value
): ):
response = event.payload.step_details.response violation = event.payload.step_details.violation
if not response.is_violation: if not violation:
yield event, LogEvent( yield event, LogEvent(
role=step_type, content="No Violation", color="magenta" role=step_type, content="No Violation", color="magenta"
) )
else: else:
yield event, LogEvent( yield event, LogEvent(
role=step_type, role=step_type,
content=f"{response.violation_type} {response.violation_return_message}", content=f"{violation.metadata} {violation.user_message}",
color="red", color="red",
) )

View file

@@ -10,21 +10,14 @@ from typing import Any, AsyncGenerator
import fire import fire
import httpx import httpx
from llama_stack.distribution.datatypes import RemoteProviderConfig
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig
from .event_logger import EventLogger from .event_logger import EventLogger
from .inference import ( from llama_stack.apis.inference import * # noqa: F403
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
UserMessage,
)
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference: async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
@@ -48,7 +41,27 @@ class InferenceClient(Inference):
async def completion(self, request: CompletionRequest) -> AsyncGenerator: async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError() raise NotImplementedError()
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
async with client.stream( async with client.stream(
"POST", "POST",

View file

@@ -248,7 +248,7 @@ class ChatAgent(ShieldRunnerMixin):
self, self,
turn_id: str, turn_id: str,
messages: List[Message], messages: List[Message],
shields: List[ShieldDefinition], shields: List[str],
touchpoint: str, touchpoint: str,
) -> AsyncGenerator: ) -> AsyncGenerator:
if len(shields) == 0: if len(shields) == 0:
@@ -608,7 +608,6 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
return True return True
print(f"{enabled_tools=}")
return AgentTool.memory.value in enabled_tools return AgentTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]: def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]: