Fix precommit check after moving to ruff (#927)

Lint check in main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file as well as fixing and ignoring some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
Yuan Tang 2025-02-02 09:46:45 -05:00 committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
217 changed files with 981 additions and 2681 deletions

View file

@ -86,9 +86,7 @@ class ShieldCallStep(StepCommon):
@json_schema_type
class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval.value] = (
StepType.memory_retrieval.value
)
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
vector_db_ids: str
inserted_context: InterleavedContent
@ -184,9 +182,7 @@ class AgentTurnResponseEventType(Enum):
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
AgentTurnResponseEventType.step_start.value
)
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
step_type: StepType
step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@ -194,9 +190,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
AgentTurnResponseEventType.step_complete.value
)
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
step_type: StepType
step_id: str
step_details: Step
@ -206,9 +200,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
AgentTurnResponseEventType.step_progress.value
)
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
step_type: StepType
step_id: str
@ -217,17 +209,13 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
AgentTurnResponseEventType.turn_start.value
)
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
turn_id: str
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
AgentTurnResponseEventType.turn_complete.value
)
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
turn: Turn
@ -329,9 +317,7 @@ class Agents(Protocol):
toolgroups: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET"
)
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
async def get_agents_turn(
self,
agent_id: str,

View file

@ -63,9 +63,7 @@ class EventLogger:
if isinstance(chunk, ToolResponseMessage):
yield (
chunk,
LogEvent(
role="CustomTool", content=chunk.content, color="grey"
),
LogEvent(role="CustomTool", content=chunk.content, color="grey"),
)
continue
@ -81,17 +79,12 @@ class EventLogger:
step_type = event.payload.step_type
# handle safety
if (
step_type == StepType.shield_call
and event_type == EventType.step_complete.value
):
if step_type == StepType.shield_call and event_type == EventType.step_complete.value:
violation = event.payload.step_details.violation
if not violation:
yield (
event,
LogEvent(
role=step_type, content="No Violation", color="magenta"
),
LogEvent(role=step_type, content="No Violation", color="magenta"),
)
else:
yield (
@ -110,9 +103,7 @@ class EventLogger:
# TODO: Currently this event is never received
yield (
event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
LogEvent(role=step_type, content="", end="", color="yellow"),
)
elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress
@ -125,9 +116,7 @@ class EventLogger:
):
yield (
event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
LogEvent(role=step_type, content="", end="", color="yellow"),
)
delta = event.payload.delta
@ -161,9 +150,7 @@ class EventLogger:
if event_type == EventType.step_complete.value:
response = event.payload.step_details.model_response
if response.tool_calls:
content = ToolUtils.encode_tool_call(
response.tool_calls[0], tool_prompt_format
)
content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format)
else:
content = response.content
yield (
@ -202,10 +189,7 @@ class EventLogger:
),
)
if (
step_type == StepType.memory_retrieval
and event_type == EventType.step_complete.value
):
if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value:
details = event.payload.step_details
inserted_context = interleaved_content_as_str(details.inserted_context)
content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"