forked from phoenix-oss/llama-stack-mirror
		
	Lint check in main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fixing and ignoring some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
		
			
				
	
	
		
			207 lines
		
	
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			207 lines
		
	
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import json
 | |
| import logging
 | |
| import shutil
 | |
| import tempfile
 | |
| import uuid
 | |
| from typing import AsyncGenerator, List, Optional, Union
 | |
| 
 | |
| from termcolor import colored
 | |
| 
 | |
| from llama_stack.apis.agents import (
 | |
|     AgentConfig,
 | |
|     AgentCreateResponse,
 | |
|     Agents,
 | |
|     AgentSessionCreateResponse,
 | |
|     AgentStepResponse,
 | |
|     AgentToolGroup,
 | |
|     AgentTurnCreateRequest,
 | |
|     Document,
 | |
|     Session,
 | |
|     Turn,
 | |
| )
 | |
| from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
 | |
| from llama_stack.apis.safety import Safety
 | |
| from llama_stack.apis.tools import ToolGroups, ToolRuntime
 | |
| from llama_stack.apis.vector_io import VectorIO
 | |
| from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 | |
| 
 | |
| from .agent_instance import ChatAgent
 | |
| from .config import MetaReferenceAgentsImplConfig
 | |
| 
 | |
| logger = logging.getLogger()
 | |
| logger.setLevel(logging.INFO)
 | |
| 
 | |
| 
 | |
| class MetaReferenceAgentsImpl(Agents):
 | |
|     def __init__(
 | |
|         self,
 | |
|         config: MetaReferenceAgentsImplConfig,
 | |
|         inference_api: Inference,
 | |
|         vector_io_api: VectorIO,
 | |
|         safety_api: Safety,
 | |
|         tool_runtime_api: ToolRuntime,
 | |
|         tool_groups_api: ToolGroups,
 | |
|     ):
 | |
|         self.config = config
 | |
|         self.inference_api = inference_api
 | |
|         self.vector_io_api = vector_io_api
 | |
|         self.safety_api = safety_api
 | |
|         self.tool_runtime_api = tool_runtime_api
 | |
|         self.tool_groups_api = tool_groups_api
 | |
| 
 | |
|         self.in_memory_store = InmemoryKVStoreImpl()
 | |
|         self.tempdir = tempfile.mkdtemp()
 | |
| 
 | |
|     async def initialize(self) -> None:
 | |
|         self.persistence_store = await kvstore_impl(self.config.persistence_store)
 | |
| 
 | |
|         # check if "bwrap" is available
 | |
|         if not shutil.which("bwrap"):
 | |
|             print(
 | |
|                 colored(
 | |
|                     "Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
 | |
|                     "yellow",
 | |
|                 )
 | |
|             )
 | |
| 
 | |
|     async def create_agent(
 | |
|         self,
 | |
|         agent_config: AgentConfig,
 | |
|     ) -> AgentCreateResponse:
 | |
|         agent_id = str(uuid.uuid4())
 | |
| 
 | |
|         await self.persistence_store.set(
 | |
|             key=f"agent:{agent_id}",
 | |
|             value=agent_config.model_dump_json(),
 | |
|         )
 | |
|         return AgentCreateResponse(
 | |
|             agent_id=agent_id,
 | |
|         )
 | |
| 
 | |
|     async def get_agent(self, agent_id: str) -> ChatAgent:
 | |
|         agent_config = await self.persistence_store.get(
 | |
|             key=f"agent:{agent_id}",
 | |
|         )
 | |
|         if not agent_config:
 | |
|             raise ValueError(f"Could not find agent config for {agent_id}")
 | |
| 
 | |
|         try:
 | |
|             agent_config = json.loads(agent_config)
 | |
|         except json.JSONDecodeError as e:
 | |
|             raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
 | |
| 
 | |
|         try:
 | |
|             agent_config = AgentConfig(**agent_config)
 | |
|         except Exception as e:
 | |
|             raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e
 | |
| 
 | |
|         return ChatAgent(
 | |
|             agent_id=agent_id,
 | |
|             agent_config=agent_config,
 | |
|             tempdir=self.tempdir,
 | |
|             inference_api=self.inference_api,
 | |
|             safety_api=self.safety_api,
 | |
|             vector_io_api=self.vector_io_api,
 | |
|             tool_runtime_api=self.tool_runtime_api,
 | |
|             tool_groups_api=self.tool_groups_api,
 | |
|             persistence_store=(
 | |
|                 self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
 | |
|             ),
 | |
|         )
 | |
| 
 | |
|     async def create_agent_session(
 | |
|         self,
 | |
|         agent_id: str,
 | |
|         session_name: str,
 | |
|     ) -> AgentSessionCreateResponse:
 | |
|         agent = await self.get_agent(agent_id)
 | |
| 
 | |
|         session_id = await agent.create_session(session_name)
 | |
|         return AgentSessionCreateResponse(
 | |
|             session_id=session_id,
 | |
|         )
 | |
| 
 | |
|     async def create_agent_turn(
 | |
|         self,
 | |
|         agent_id: str,
 | |
|         session_id: str,
 | |
|         messages: List[
 | |
|             Union[
 | |
|                 UserMessage,
 | |
|                 ToolResponseMessage,
 | |
|             ]
 | |
|         ],
 | |
|         toolgroups: Optional[List[AgentToolGroup]] = None,
 | |
|         documents: Optional[List[Document]] = None,
 | |
|         stream: Optional[bool] = False,
 | |
|     ) -> AsyncGenerator:
 | |
|         request = AgentTurnCreateRequest(
 | |
|             agent_id=agent_id,
 | |
|             session_id=session_id,
 | |
|             messages=messages,
 | |
|             stream=True,
 | |
|             toolgroups=toolgroups,
 | |
|             documents=documents,
 | |
|         )
 | |
|         if stream:
 | |
|             return self._create_agent_turn_streaming(request)
 | |
|         else:
 | |
|             raise NotImplementedError("Non-streaming agent turns not yet implemented")
 | |
| 
 | |
|     async def _create_agent_turn_streaming(
 | |
|         self,
 | |
|         request: AgentTurnCreateRequest,
 | |
|     ) -> AsyncGenerator:
 | |
|         agent = await self.get_agent(request.agent_id)
 | |
|         async for event in agent.create_and_execute_turn(request):
 | |
|             yield event
 | |
| 
 | |
|     async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
 | |
|         turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
 | |
|         turn = json.loads(turn)
 | |
|         turn = Turn(**turn)
 | |
|         return turn
 | |
| 
 | |
|     async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
 | |
|         turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
 | |
|         turn = json.loads(turn)
 | |
|         turn = Turn(**turn)
 | |
|         steps = turn.steps
 | |
|         for step in steps:
 | |
|             if step.step_id == step_id:
 | |
|                 return AgentStepResponse(step=step)
 | |
|         raise ValueError(f"Provided step_id {step_id} could not be found")
 | |
| 
 | |
|     async def get_agents_session(
 | |
|         self,
 | |
|         agent_id: str,
 | |
|         session_id: str,
 | |
|         turn_ids: Optional[List[str]] = None,
 | |
|     ) -> Session:
 | |
|         session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
 | |
|         session = Session(**json.loads(session), turns=[])
 | |
|         turns = []
 | |
|         if turn_ids:
 | |
|             for turn_id in turn_ids:
 | |
|                 turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
 | |
|                 turn = json.loads(turn)
 | |
|                 turn = Turn(**turn)
 | |
|                 turns.append(turn)
 | |
|         return Session(
 | |
|             session_name=session.session_name,
 | |
|             session_id=session_id,
 | |
|             turns=turns if turns else [],
 | |
|             started_at=session.started_at,
 | |
|         )
 | |
| 
 | |
|     async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
 | |
|         await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
 | |
| 
 | |
|     async def delete_agent(self, agent_id: str) -> None:
 | |
|         await self.persistence_store.delete(f"agent:{agent_id}")
 |