diff --git a/litellm/main.py b/litellm/main.py index de0716fd96..299d35468d 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -24,6 +24,7 @@ from copy import deepcopy from functools import partial from typing import ( Any, + Awaitable, Callable, Coroutine, Dict, @@ -461,19 +462,23 @@ async def acompletion( try: # Use a partial function to pass your keyword arguments - func = partial(completion, **completion_kwargs, **kwargs) + # func = partial(completion, **completion_kwargs, **kwargs) - # Add the context to the function - ctx = contextvars.copy_context() - func_with_context = partial(ctx.run, func) + # # Add the context to the function + # ctx = contextvars.copy_context() + # func_with_context = partial(ctx.run, func) - init_response = await loop.run_in_executor(None, func_with_context) + init_response = await cast( + Awaitable[Union[dict, ModelResponse, CustomStreamWrapper]], + completion(**completion_kwargs, **kwargs), + ) if isinstance(init_response, dict) or isinstance( init_response, ModelResponse ): ## CACHING SCENARIO if isinstance(init_response, dict): response = ModelResponse(**init_response) - response = init_response + else: + response = init_response elif asyncio.iscoroutine(init_response): response = await init_response else: