litellm/tests/local_testing/test_triton.py

31 lines
867 B
Python

import pytest
from litellm.llms.triton import TritonChatCompletion
def test_split_embedding_by_shape_passes():
try:
triton = TritonChatCompletion()
data = [
{
"shape": [2, 3],
"data": [1, 2, 3, 4, 5, 6],
}
]
split_output_data = triton.split_embedding_by_shape(
data[0]["data"], data[0]["shape"]
)
assert split_output_data == [[1, 2, 3], [4, 5, 6]]
except Exception as e:
pytest.fail(f"An exception occured: {e}")
def test_split_embedding_by_shape_fails_with_shape_value_error():
triton = TritonChatCompletion()
data = [
{
"shape": [2],
"data": [1, 2, 3, 4, 5, 6],
}
]
with pytest.raises(ValueError):
triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])