mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
29 lines
845 B
Python
29 lines
845 B
Python
import pytest
|
|
from litellm.llms.triton import TritonChatCompletion
|
|
|
|
|
|
def test_split_embedding_by_shape_passes():
|
|
try:
|
|
triton = TritonChatCompletion()
|
|
data = [
|
|
{
|
|
"shape": [2, 3],
|
|
"data": [1, 2, 3, 4, 5, 6],
|
|
}
|
|
]
|
|
split_output_data = triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
|
|
assert split_output_data == [[1, 2, 3], [4, 5, 6]]
|
|
except Exception as e:
|
|
pytest.fail(f"An exception occured: {e}")
|
|
|
|
|
|
def test_split_embedding_by_shape_fails_with_shape_value_error():
|
|
triton = TritonChatCompletion()
|
|
data = [
|
|
{
|
|
"shape": [2],
|
|
"data": [1, 2, 3, 4, 5, 6],
|
|
}
|
|
]
|
|
with pytest.raises(ValueError):
|
|
triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
|