diff --git a/.circleci/config.yml b/.circleci/config.yml
index 0fe72ef80..b4523e458 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -25,6 +25,7 @@ jobs:
             python -m pip install -r .circleci/requirements.txt
             pip install infisical
             pip install pytest
+            pip install pytest-asyncio
             pip install mypy
             pip install openai[datalib]
             pip install -Uq chromadb==0.3.29
diff --git a/litellm/main.py b/litellm/main.py
index 395cae31b..73d291a39 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -60,8 +60,17 @@ async def acompletion(*args, **kwargs):
     func_with_context = partial(ctx.run, func)
 
     # Call the synchronous function using run_in_executor
-    return await loop.run_in_executor(None, func_with_context)
-
+    response = await loop.run_in_executor(None, func_with_context)
+    if kwargs.get("stream", False): # return an async generator
+        # do not change this
+        # for stream = True, always return an async generator
+        # See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
+        return(
+            line
+            async for line in response
+        )
+    else:
+        return response
 
 @client
 @timeout( # type: ignore
diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py
index da6004c73..a31f20c15 100644
--- a/litellm/tests/test_async_fn.py
+++ b/litellm/tests/test_async_fn.py
@@ -11,7 +11,7 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 from litellm import acompletion, acreate
 
-
+@pytest.mark.asyncio
 async def test_get_response():
     user_message = "Hello, how are you?"
     messages = [{"content": user_message, "role": "user"}]
@@ -22,8 +22,39 @@ async def test_get_response():
     return response
 
 
-response = asyncio.run(test_get_response())
-print(response)
+# response = asyncio.run(test_get_response())
+# print(response)
+
+@pytest.mark.asyncio
+async def test_get_response_streaming():
+    user_message = "Hello, how are you?"
+    messages = [{"content": user_message, "role": "user"}]
+    try:
+        response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
+        print(type(response))
+
+        import inspect
+
+        is_async_generator = inspect.isasyncgen(response)
+        print(is_async_generator)
+
+        output = ""
+        async for chunk in response:
+            token = chunk["choices"][0]["delta"].get("content", "")
+            output += token
+        print(output)
+
+        assert output is not None, "Agent output cannot be None."
+        assert isinstance(output, str), "Agent output needs to be of type str"
+        assert len(output) > 0, "Length of output needs to be greater than 0."
+
+    except Exception as e:
+        pytest.fail(f"error occurred: {e}")
+    return response
+
+# response = asyncio.run(test_get_response_streaming())
+# print(response)
+
 
 # async def test_get_response():
 #     user_message = "Hello, how are you?"
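
Usage sketch (not part of the diff): after this change, awaiting acompletion(..., stream=True) resolves to an async generator of OpenAI-style chunks rather than a completed response, so callers iterate it with `async for`, mirroring the new test above. A minimal caller, assuming OPENAI_API_KEY is configured and using the same gpt-3.5-turbo model as the test:

    import asyncio
    from litellm import acompletion

    async def main():
        response = await acompletion(
            model="gpt-3.5-turbo",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
            stream=True,
        )
        # response is an async generator; each chunk carries the next
        # token in chunk["choices"][0]["delta"]["content"]
        async for chunk in response:
            print(chunk["choices"][0]["delta"].get("content", ""), end="")

    asyncio.run(main())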