refactor(test_async_fn.py): refactor for testing

This commit is contained in:
Krrish Dholakia 2023-10-24 12:46:30 -07:00
parent b7a023a82b
commit 2c371bb8d1
2 changed files with 34 additions and 34 deletions

View file

@ -11,8 +11,9 @@ sys.path.insert(
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
from litellm import acompletion, acreate from litellm import acompletion, acreate
@pytest.mark.asyncio def test_async_response():
async def test_get_response(): import asyncio
async def test_get_response():
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
try: try:
@ -20,11 +21,12 @@ async def test_get_response():
except Exception as e: except Exception as e:
pass pass
response = asyncio.run(test_get_response()) response = asyncio.run(test_get_response())
# print(response) # print(response)
@pytest.mark.asyncio def test_get_response_streaming():
async def test_get_response_streaming(): import asyncio
async def test_async_call():
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
try: try:
@ -49,8 +51,6 @@ async def test_get_response_streaming():
except Exception as e: except Exception as e:
pass pass
return response return response
asyncio.run(test_async_call())
# response = asyncio.run(test_get_response_streaming())
# print(response)

View file

@ -20,7 +20,6 @@ import aiohttp
import logging import logging
import asyncio import asyncio
from tokenizers import Tokenizer from tokenizers import Tokenizer
import pkg_resources
from dataclasses import ( from dataclasses import (
dataclass, dataclass,
field, field,
@ -875,6 +874,7 @@ def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
def _select_tokenizer(model: str): def _select_tokenizer(model: str):
# cohere # cohere
import pkg_resources
if model in litellm.cohere_models: if model in litellm.cohere_models:
tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly") tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly")
return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}