diff --git a/src/main.py b/src/main.py index d32018f..1905ff4 100644 --- a/src/main.py +++ b/src/main.py @@ -38,8 +38,23 @@ async def get_report_endpoint(request: ReportRequest): # Call the asynchronous get_report function yield "Report generation started...\n" generator = ReportGenerator(request.query, request.report_type) - custom_logs_handler = await generator.generate_report() + custom_logs_handler = generator.init() + generator.generate_report() yield "Report generation completed successfully!\n" + index = 0 + while not generator.is_complete(): + # If there are more logs to send, yield them + if index < len(custom_logs_handler.logs): + log_entry = custom_logs_handler.logs[index] + index += 1 + yield f"{log_entry}\n" # Convert logs to string for streaming + else: + # Wait briefly to avoid aggressive looping + await asyncio.sleep(0.1) + # Stop if processing is complete and no more logs remain + if generator.researcher.is_complete: + break + except Exception as e: yield f"Error: {str(e)}" diff --git a/src/phoenix_technologies/gptresearch/deepresearch.py b/src/phoenix_technologies/gptresearch/deepresearch.py index 919811c..77ad753 100644 --- a/src/phoenix_technologies/gptresearch/deepresearch.py +++ b/src/phoenix_technologies/gptresearch/deepresearch.py @@ -21,10 +21,17 @@ class ReportGenerator: self.report_type = report_type # Initialize researcher with a custom WebSocket self.custom_logs_handler = CustomLogsHandler() + self.complete = False self.researcher = GPTResearcher(query, report_type, websocket=self.custom_logs_handler) - async def generate_report(self) -> CustomLogsHandler: + def init(self) -> CustomLogsHandler: + return self.custom_logs_handler + + def is_complete(self): + return self.complete + + async def generate_report(self) -> None: """ Conducts research and generates the report along with additional information. """ @@ -37,8 +44,7 @@ class ReportGenerator: research_costs = self.researcher.get_costs() research_images = self.researcher.get_research_images() research_sources = self.researcher.get_research_sources() - - return self.custom_logs_handler + self.complete = True def get_query_details(self): """