diff --git a/src/main.py b/src/main.py index 95db1a8..3bbc6cf 100644 --- a/src/main.py +++ b/src/main.py @@ -36,9 +36,9 @@ async def get_report_endpoint(request: ReportRequest): async def generate_report(): try: # Call the asynchronous get_report function - report_generator = ReportGenerator(request.query, request.report_type) - async for chunk in report_generator: - yield chunk + generator = ReportGenerator(request.query, request.report_type) + async for log in generator: + yield log except Exception as e: yield f"Error: {str(e)}" diff --git a/src/phoenix_technologies/gptresearch/deepresearch.py b/src/phoenix_technologies/gptresearch/deepresearch.py index e693f45..add36e3 100644 --- a/src/phoenix_technologies/gptresearch/deepresearch.py +++ b/src/phoenix_technologies/gptresearch/deepresearch.py @@ -1,15 +1,18 @@ from gpt_researcher import GPTResearcher -from typing import Dict, Any, AsyncGenerator +from typing import Dict, Any, AsyncGenerator, Coroutine class CustomLogsHandler: """A custom Logs handler class to handle JSON data.""" - def __init__(self): - self.logs = [] # Initialize logs to store data + def __init__(self, logs=None): + if logs is None: + logs = [] + self.logs = logs # Initialize logs to store data - async def send_json(self, data: Dict[str, Any]) -> AsyncGenerator[str, Any]: + async def send_json(self, data: Dict[str, Any]) -> None: """Send JSON data and log it.""" - yield f"My custom Log: {data}" + self.logs.append(data) # Append data to logs + print(f"My custom Log: {data}") # For demonstration, print the log class ReportGenerator: def __init__(self, query: str, report_type: str): @@ -19,7 +22,8 @@ class ReportGenerator: self.query = query self.report_type = report_type # Initialize researcher with a custom WebSocket - self.custom_logs_handler = CustomLogsHandler() + self.logs = [] + self.custom_logs_handler = CustomLogsHandler(self.logs) self.researcher = GPTResearcher(query, report_type, websocket=self.custom_logs_handler) @@ -47,3 +51,18 @@ class ReportGenerator: "query": self.query, "report_type": self.report_type } + + # Implementing the asynchronous iterator protocol + def __aiter__(self): + """Initialize and return the asynchronous iterator.""" + self._log_index = 0 # Iterator index + return self + + async def __anext__(self): + """Return the next log asynchronously.""" + if self._log_index < len(self.logs): + log = self.logs[self._log_index] + self._log_index += 1 + return log # Return the next log + else: + raise StopAsyncIteration # Stop when logs are exhausted