mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
build: Squashed commit of https://github.com/BerriAI/litellm/pull/7170
Closes https://github.com/BerriAI/litellm/pull/7170
This commit is contained in:
parent
5d1274cb6e
commit
06074bb13b
8 changed files with 197 additions and 62 deletions
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
from litellm.llms.triton import TritonChatCompletion
|
||||
from litellm.llms.triton.completion.handler import TritonChatCompletion
|
||||
|
||||
|
||||
def test_split_embedding_by_shape_passes():
|
||||
|
@ -29,3 +29,28 @@ def test_split_embedding_by_shape_fails_with_shape_value_error():
|
|||
]
|
||||
with pytest.raises(ValueError):
|
||||
triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
|
||||
|
||||
|
||||
def test_completion_triton():
    """Check that a Triton completion call goes through the supplied client.

    The ``post`` method of the ``HTTPHandler`` instance passed as ``client``
    is patched, and the test asserts only that it was invoked exactly once.
    The content of the completion response is not verified.
    """
    from unittest.mock import patch

    from litellm import completion
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    client = HTTPHandler()

    with patch.object(client, "post") as mock_post:
        try:
            response = completion(
                model="triton/llama-3-8b-instruct",
                messages=[{"role": "user", "content": "who are u?"}],
                max_tokens=10,
                timeout=5,
                client=client,
                api_base="http://localhost:8000/generate",
            )
        except Exception as e:
            # Deliberately best-effort: the patched ``post`` hands back a mock
            # object instead of a real HTTP response, so downstream handling
            # may raise (presumably while parsing the response -- confirm).
            # The assertion below is what this test actually checks.
            print(e)
        else:
            print(response)

        mock_post.assert_called_once()

        # Dump the keyword arguments of the intercepted call for debugging.
        print(mock_post.call_args.kwargs)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue