mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 17:11:12 +00:00 
			
		
		
		
	We would like to rename the term `template` to `distribution`; this change is a precursor to that rename. cc @leseb
		
			
				
	
	
		
			40 lines
		
	
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			40 lines
		
	
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| from collections.abc import AsyncGenerator
 | |
| from contextvars import ContextVar
 | |
| 
 | |
| 
 | |
| def preserve_contexts_async_generator[T](
 | |
|     gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
 | |
| ) -> AsyncGenerator[T, None]:
 | |
|     """
 | |
|     Wraps an async generator to preserve context variables across iterations.
 | |
|     This is needed because we start a new asyncio event loop for each streaming request,
 | |
|     and we need to preserve the context across the event loop boundary.
 | |
|     """
 | |
|     # Capture initial context values
 | |
|     initial_context_values = {context_var.name: context_var.get() for context_var in context_vars}
 | |
| 
 | |
|     async def wrapper() -> AsyncGenerator[T, None]:
 | |
|         while True:
 | |
|             try:
 | |
|                 # Restore context values before any await
 | |
|                 for context_var in context_vars:
 | |
|                     context_var.set(initial_context_values[context_var.name])
 | |
| 
 | |
|                 item = await gen.__anext__()
 | |
| 
 | |
|                 # Update our tracked values with any changes made during this iteration
 | |
|                 for context_var in context_vars:
 | |
|                     initial_context_values[context_var.name] = context_var.get()
 | |
| 
 | |
|                 yield item
 | |
| 
 | |
|             except StopAsyncIteration:
 | |
|                 break
 | |
| 
 | |
|     return wrapper()
 |