Added asynchronous iteration methods (`__aiter__` and `__anext__`) to the `DeepResearch` class to enable streaming of research report chunks. Introduced functionality to generate the research report and split it into smaller chunks that can be streamed one at a time.
82 lines
2.6 KiB
Python
82 lines
2.6 KiB
Python
from gpt_researcher import GPTResearcher
|
|
|
|
|
|
class ReportGenerator:
    """Run a GPT Researcher query and stream the results as text chunks.

    An instance is a single-use asynchronous iterator. The first ``async for``
    step conducts the research and writes the report. Each step then yields
    one ``"<key>: <value>"`` string. Once exhausted, iterating the same
    instance again yields nothing.
    """

    def __init__(self, query: str, report_type: str):
        """
        Initialize the ReportGenerator with a query and report type.

        Args:
            query: The research question, passed to ``GPTResearcher``.
            report_type: The report type, passed to ``GPTResearcher``.
        """
        self.query = query
        self.report_type = report_type
        self.researcher = GPTResearcher(query, report_type)
        # Chunks are produced lazily on the first ``__anext__`` call.
        # ``None`` means "not generated yet", as distinct from an empty list.
        self._chunks = None
        # Position of the next chunk to yield.
        self._index = 0

    def __aiter__(self):
        """Return ``self``: this object is its own asynchronous iterator."""
        return self

    async def __anext__(self) -> str:
        """
        Return the next report chunk, generating the report on first use.

        Raises:
            StopAsyncIteration: Once every chunk has been yielded.
        """
        if self._chunks is None:
            self._chunks = await self._generate_report_chunks()

        if self._index >= len(self._chunks):
            raise StopAsyncIteration

        chunk = self._chunks[self._index]
        self._index += 1
        return chunk

    async def _generate_report_chunks(self) -> list:
        """
        Conduct the research, write the report, and return it as chunks.

        Returns:
            A list of ``"<key>: <value>"`` strings for the report, context,
            costs, images and sources, in that order.
        """
        # Only the side effect of conduct_research() matters here (presumably
        # it gathers the material write_report() and the getters read); its
        # return value was never used.
        await self.researcher.conduct_research()
        report = await self.researcher.write_report()

        full_report = {
            "report": report,
            "context": self.researcher.get_research_context(),
            "costs": self.researcher.get_costs(),
            "images": self.researcher.get_research_images(),
            "sources": self.researcher.get_research_sources(),
        }

        return self._split_into_chunks(full_report)

    def _split_into_chunks(self, report):
        """
        Split a report dictionary into one string chunk per entry.

        Args:
            report: Mapping whose items become chunks, in iteration order.

        Returns:
            A list of ``f"{key}: {value}"`` strings; empty for an empty mapping.
        """
        return [f"{key}: {value}" for key, value in report.items()]

    def get_query_details(self):
        """
        Return the query and report type this generator was created with.

        Returns:
            A dict with keys ``"query"`` and ``"report_type"``.
        """
        return {
            "query": self.query,
            "report_type": self.report_type
        }