Krrish Dholakia 2024-12-11 01:03:57 -08:00
parent 5d1274cb6e
commit 06074bb13b
8 changed files with 197 additions and 62 deletions

View file

@ -1,5 +1,5 @@
import pytest
from litellm.llms.triton import TritonChatCompletion
from litellm.llms.triton.completion.handler import TritonChatCompletion
def test_split_embedding_by_shape_passes():
@ -29,3 +29,28 @@ def test_split_embedding_by_shape_fails_with_shape_value_error():
]
with pytest.raises(ValueError):
triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
def test_completion_triton():
from litellm import completion
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock
client = HTTPHandler()
with patch.object(client, "post") as mock_post:
try:
response = completion(
model="triton/llama-3-8b-instruct",
messages=[{"role": "user", "content": "who are u?"}],
max_tokens=10,
timeout=5,
client=client,
api_base="http://localhost:8000/generate",
)
print(response)
except Exception as e:
print(e)
mock_post.assert_called_once()
print(mock_post.call_args.kwargs)