diff --git a/src/phoenix_technologies/gptresearch/deepresearch.py b/src/phoenix_technologies/gptresearch/deepresearch.py index fcabcf9..e3368ac 100644 --- a/src/phoenix_technologies/gptresearch/deepresearch.py +++ b/src/phoenix_technologies/gptresearch/deepresearch.py @@ -9,10 +9,35 @@ class ReportGenerator: self.query = query self.report_type = report_type self.researcher = GPTResearcher(query, report_type) + self._chunks = None # Placeholder for report chunks + self._index = 0 # Index for iteration - async def generate_report(self): + def __aiter__(self): """ - Conducts research and generates the report along with additional information. + Make this class asynchronously iterable. + """ + return self + + async def __anext__(self): + """ + Defines the asynchronous iteration logic. + """ + if self._chunks is None: + # If chunks are not generated yet, generate the report + self._chunks = await self._generate_report_chunks() + + if self._index >= len(self._chunks): + # Stop iteration when all chunks are yielded + raise StopAsyncIteration + + # Return the next chunk and increment the index + chunk = self._chunks[self._index] + self._index += 1 + return chunk + + async def _generate_report_chunks(self): + """ + Conducts research and generates the report in chunks. """ # Conduct research research_result = await self.researcher.conduct_research() @@ -24,14 +49,29 @@ class ReportGenerator: research_images = self.researcher.get_research_images() research_sources = self.researcher.get_research_sources() - return { + # Construct the full response + full_report = { "report": report, "context": research_context, "costs": research_costs, "images": research_images, - "sources": research_sources + "sources": research_sources, } + # Split the report into smaller chunks for streaming + return self._split_into_chunks(full_report) + + def _split_into_chunks(self, report): + """ + Splits a report dictionary into smaller chunks for streaming. + """ + # Convert the report dictionary into a list of key-value pairs, + # where each pair represents a chunk. + chunks = [] + for key, value in report.items(): + chunks.append(f"{key}: {value}") + return chunks # Return the list of chunks + def get_query_details(self): """ Returns details of the query and report type.