diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
index a1ba6cb48..788421da9 100644
--- a/llama_toolchain/agentic_system/client.py
+++ b/llama_toolchain/agentic_system/client.py
@@ -132,7 +132,7 @@ async def run_main(host: str, port: int):
     api = AgenticSystemClient(f"http://{host}:{port}")
 
     tool_definitions = [
-        SearchToolDefinition(engine=SearchEngineType.brave),
+        SearchToolDefinition(engine=SearchEngineType.bing),
         WolframAlphaToolDefinition(),
         CodeInterpreterToolDefinition(),
     ]
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index f9a2d20dd..202f42a3c 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -388,19 +388,17 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )
 
-        req = ChatCompletionRequest(
-            model=self.agent_config.model,
-            messages=input_messages,
+        tool_calls = []
+        content = ""
+        stop_reason = None
+        async for chunk in self.inference_api.chat_completion(
+            self.agent_config.model,
+            input_messages,
             tools=self._get_tools(),
             tool_prompt_format=self.agent_config.tool_prompt_format,
             stream=True,
             sampling_params=sampling_params,
-        )
-
-        tool_calls = []
-        content = ""
-        stop_reason = None
-        async for chunk in self.inference_api.chat_completion_impl(req):
+        ):
             event = chunk.event
             if event.event_type == ChatCompletionResponseEventType.start:
                 continue
diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_toolchain/inference/meta_reference/inference.py
index 7eb6d5dc4..247c08f23 100644
--- a/llama_toolchain/inference/meta_reference/inference.py
+++ b/llama_toolchain/inference/meta_reference/inference.py
@@ -77,14 +77,6 @@ class MetaReferenceInferenceImpl(Inference):
             logprobs=logprobs,
         )
 
-        async for chunk in self.chat_completion_impl(request):
-            yield chunk
-
-    async def chat_completion_impl(
-        self, request: ChatCompletionRequest
-    ) -> AsyncIterator[
-        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
-    ]:
         messages = prepare_messages(request)
         model = resolve_model(request.model)
         if model is None:
diff --git a/llama_toolchain/telemetry/api/api.py b/llama_toolchain/telemetry/api/api.py
index 460c2fee2..2546c1ede 100644
--- a/llama_toolchain/telemetry/api/api.py
+++ b/llama_toolchain/telemetry/api/api.py
@@ -125,11 +125,7 @@ Event = Annotated[
 
 class Telemetry(Protocol):
     @webmethod(route="/telemetry/log_event")
-<<<<<<< migrate_request_wrapper
     async def log_event(self, event: Event) -> None: ...
-=======
-    async def log_event(self, event: Event): ...
->>>>>>> main
 
     @webmethod(route="/telemetry/get_trace", method="GET")
     async def get_trace(self, trace_id: str) -> Trace: ...
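Note (not part of the patch): a minimal sketch of the call pattern the agent adopts above, consuming the streaming chat_completion generator directly instead of building a ChatCompletionRequest and calling the removed chat_completion_impl. The stub generator and chunk type are hypothetical stand-ins for the real Inference API, not code from this repository.

import asyncio
from dataclasses import dataclass


@dataclass
class FakeChunk:
    # Hypothetical stand-in for a streamed chat-completion chunk.
    delta: str


async def fake_chat_completion(model, messages, stream=True, **kwargs):
    # Hypothetical stand-in for Inference.chat_completion(): an async
    # generator that yields chunks as they are produced.
    for piece in ("Hello", ", ", "world"):
        yield FakeChunk(delta=piece)


async def main():
    content = ""
    # Iterate the generator returned by chat_completion() directly,
    # accumulating streamed content as it arrives.
    async for chunk in fake_chat_completion(
        "example-model",
        [{"role": "user", "content": "hi"}],
        stream=True,
    ):
        content += chunk.delta
    print(content)  # -> "Hello, world"


asyncio.run(main())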