litellm-mirror/tests/local_testing/test_triton.py

56 lines
1.6 KiB
Python

import pytest
from litellm.llms.triton.completion.handler import TritonChatCompletion
def test_split_embedding_by_shape_passes():
try:
triton = TritonChatCompletion()
data = [
{
"shape": [2, 3],
"data": [1, 2, 3, 4, 5, 6],
}
]
split_output_data = triton.split_embedding_by_shape(
data[0]["data"], data[0]["shape"]
)
assert split_output_data == [[1, 2, 3], [4, 5, 6]]
except Exception as e:
pytest.fail(f"An exception occured: {e}")
def test_split_embedding_by_shape_fails_with_shape_value_error():
triton = TritonChatCompletion()
data = [
{
"shape": [2],
"data": [1, 2, 3, 4, 5, 6],
}
]
with pytest.raises(ValueError):
triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
def test_completion_triton():
from litellm import completion
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock
client = HTTPHandler()
with patch.object(client, "post") as mock_post:
try:
response = completion(
model="triton/llama-3-8b-instruct",
messages=[{"role": "user", "content": "who are u?"}],
max_tokens=10,
timeout=5,
client=client,
api_base="http://localhost:8000/generate",
)
print(response)
except Exception as e:
print(e)
mock_post.assert_called_once()
print(mock_post.call_args.kwargs)