refactor(test_async_fn.py): refactor for testing

This commit is contained in:
Krrish Dholakia 2023-10-24 12:46:30 -07:00
parent b7a023a82b
commit 2c371bb8d1
2 changed files with 34 additions and 34 deletions

View file

@ -11,46 +11,46 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 from litellm import acompletion, acreate
def test_async_response():
    """Smoke-test litellm's async `acompletion` by driving it from a sync test.

    The commit replaces the `@pytest.mark.asyncio` coroutine test with a plain
    sync test that runs its own event loop via `asyncio.run`, so the suite no
    longer needs the pytest-asyncio plugin.
    """
    import asyncio

    async def test_get_response():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
        except Exception as e:
            # NOTE(review): broad swallow — presumably a best-effort smoke test
            # that tolerates missing API keys / network errors; confirm intent.
            pass

    response = asyncio.run(test_get_response())
    # print(response)
def test_get_response_streaming():
    """Smoke-test streaming `acompletion`: the async call should yield an
    async generator whose chunks concatenate into a non-empty string.

    Like `test_async_response`, this runs its own event loop with
    `asyncio.run` instead of relying on `@pytest.mark.asyncio`.
    """
    import asyncio

    async def test_async_call():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(
                model="gpt-3.5-turbo", messages=messages, stream=True
            )
            print(type(response))

            import inspect

            # Streaming mode is expected to return an async generator.
            is_async_generator = inspect.isasyncgen(response)
            print(is_async_generator)

            # Accumulate the streamed delta tokens into one string.
            output = ""
            async for chunk in response:
                token = chunk["choices"][0]["delta"].get("content", "")
                output += token
            print(output)
            assert output is not None, "output cannot be None."
            assert isinstance(output, str), "output needs to be of type str"
            assert len(output) > 0, "Length of output needs to be greater than 0."
        except Exception as e:
            # NOTE(review): broad swallow — also hides the assertion failures
            # above, so this test can never fail; confirm that is intended.
            pass
        return response

    asyncio.run(test_async_call())

# response = asyncio.run(test_get_response_streaming())
# print(response)

View file

@ -20,7 +20,6 @@ import aiohttp
 import logging
 import asyncio
 from tokenizers import Tokenizer
-import pkg_resources
 from dataclasses import (
     dataclass,
     field,
@ -875,6 +874,7 @@ def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
 def _select_tokenizer(model: str):
     # cohere
+    import pkg_resources
     if model in litellm.cohere_models:
         tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly")
         return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}