forked from phoenix/litellm-mirror
31 lines
867 B
Python
31 lines
867 B
Python
import pytest
|
|
from litellm.llms.triton import TritonChatCompletion
|
|
|
|
|
|
def test_split_embedding_by_shape_passes():
|
|
try:
|
|
triton = TritonChatCompletion()
|
|
data = [
|
|
{
|
|
"shape": [2, 3],
|
|
"data": [1, 2, 3, 4, 5, 6],
|
|
}
|
|
]
|
|
split_output_data = triton.split_embedding_by_shape(
|
|
data[0]["data"], data[0]["shape"]
|
|
)
|
|
assert split_output_data == [[1, 2, 3], [4, 5, 6]]
|
|
except Exception as e:
|
|
pytest.fail(f"An exception occured: {e}")
|
|
|
|
|
|
def test_split_embedding_by_shape_fails_with_shape_value_error():
|
|
triton = TritonChatCompletion()
|
|
data = [
|
|
{
|
|
"shape": [2],
|
|
"data": [1, 2, 3, 4, 5, 6],
|
|
}
|
|
]
|
|
with pytest.raises(ValueError):
|
|
triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
|