# litellm-mirror/tests/llm_translation/test_triton.py
import json
import os
import sys
import traceback
import io
from unittest.mock import AsyncMock, MagicMock, patch

from dotenv import load_dotenv

load_dotenv()

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import pytest

import litellm
from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig


def test_split_embedding_by_shape_passes():
    try:
        data = [
            {
                "shape": [2, 3],
                "data": [1, 2, 3, 4, 5, 6],
            }
        ]
        split_output_data = TritonEmbeddingConfig.split_embedding_by_shape(
            data[0]["data"], data[0]["shape"]
        )
        assert split_output_data == [[1, 2, 3], [4, 5, 6]]
    except Exception as e:
        pytest.fail(f"An exception occurred: {e}")


def test_split_embedding_by_shape_fails_with_shape_value_error():
    data = [
        {
            "shape": [2],
            "data": [1, 2, 3, 4, 5, 6],
        }
    ]
    with pytest.raises(ValueError):
        TritonEmbeddingConfig.split_embedding_by_shape(
            data[0]["data"], data[0]["shape"]
        )
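

# A hedged extension of the two tests above: assuming the same row-major split
# semantics verified in test_split_embedding_by_shape_passes (shape is
# [rows, cols] and the flat data is laid out row by row), these parametrized
# cases sketch how other 2-D shapes should split. The expected outputs are
# inferred from that test, not taken from litellm's own suite.
@pytest.mark.parametrize(
    "shape, data, expected",
    [
        ([3, 2], [1, 2, 3, 4, 5, 6], [[1, 2], [3, 4], [5, 6]]),
        ([1, 6], [1, 2, 3, 4, 5, 6], [[1, 2, 3, 4, 5, 6]]),
    ],
)
def test_split_embedding_by_shape_row_major_sketch(shape, data, expected):
    assert TritonEmbeddingConfig.split_embedding_by_shape(data, shape) == expected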


@pytest.mark.parametrize("stream", [True, False])
def test_completion_triton_generate_api(stream):
    try:
        mock_response = MagicMock()

        if stream:

            def mock_iter_lines():
                mock_output = "".join(
                    [
                        'data: {"model_name":"ensemble","model_version":"1","sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"'
                        + t
                        + '"}\n\n'
                        for t in ["I", " am", " an", " AI", " assistant"]
                    ]
                )
                for out in mock_output.split("\n"):
                    yield out

            mock_response.iter_lines = mock_iter_lines
        else:

            def return_val():
                return {
                    "text_output": "I am an AI assistant",
                }

            mock_response.json = return_val
        mock_response.status_code = 200

        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            return_value=mock_response,
        ) as mock_post:
            response = litellm.completion(
                model="triton/llama-3-8b-instruct",
                messages=[{"role": "user", "content": "who are u?"}],
                max_tokens=10,
                timeout=5,
                api_base="http://localhost:8000/generate",
                stream=stream,
            )

            # Verify the call was made
            mock_post.assert_called_once()

            # Get the arguments passed to the post request
            print("call args", mock_post.call_args)
            call_kwargs = mock_post.call_args.kwargs  # Access kwargs directly

            # Verify URL
            if stream:
                assert call_kwargs["url"] == "http://localhost:8000/generate_stream"
            else:
                assert call_kwargs["url"] == "http://localhost:8000/generate"

            # Parse the request data from the JSON string
            request_data = json.loads(call_kwargs["data"])

            # Verify request data
            assert request_data["text_input"] == "who are u?"
            assert request_data["parameters"]["max_tokens"] == 10

            # Verify response
            if stream:
                tokens = ["I", " am", " an", " AI", " assistant", None]
                idx = 0
                for chunk in response:
                    assert chunk.choices[0].delta.content == tokens[idx]
                    idx += 1
                assert idx == len(tokens)
            else:
                assert response.choices[0].message.content == "I am an AI assistant"

    except Exception as e:
        print("exception", e)
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")
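

# A small, self-contained sketch (independent of litellm internals) confirming
# that the SSE line format used in the streaming mock above round-trips through
# json.loads back to the token text the streaming assertions expect.
def test_triton_generate_stream_chunk_format_sketch():
    line = (
        'data: {"model_name":"ensemble","model_version":"1","sequence_end":false,'
        '"sequence_id":0,"sequence_start":false,"text_output":" am"}'
    )
    assert line.startswith("data: ")
    payload = json.loads(line[len("data: ") :])
    assert payload["text_output"] == " am"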


def test_completion_triton_infer_api():
    litellm.set_verbose = True
    try:
        mock_response = MagicMock()

        def return_val():
            return {
                "model_name": "basketgpt",
                "model_version": "2",
                "outputs": [
                    {
                        "name": "text_output",
                        "datatype": "BYTES",
                        "shape": [1],
                        "data": [
                            "0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027"
                        ],
                    },
                    {
                        "name": "debug_probs",
                        "datatype": "FP32",
                        "shape": [0],
                        "data": [],
                    },
                    {
                        "name": "debug_tokens",
                        "datatype": "BYTES",
                        "shape": [0],
                        "data": [],
                    },
                ],
            }

        mock_response.json = return_val
        mock_response.status_code = 200

        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            return_value=mock_response,
        ) as mock_post:
            response = litellm.completion(
                model="triton/llama-3-8b-instruct",
                messages=[
                    {
                        "role": "user",
                        "content": "0004900005025 0004900005026 0004900005027",
                    }
                ],
                api_base="http://localhost:8000/infer",
            )

            print("litellm response", response.model_dump_json(indent=4))

            # Verify the call was made
            mock_post.assert_called_once()

            # Get the arguments passed to the post request
            call_kwargs = mock_post.call_args.kwargs

            # Verify URL
            assert call_kwargs["url"] == "http://localhost:8000/infer"

            # Parse the request data from the JSON string
            request_data = json.loads(call_kwargs["data"])

            # Verify request matches expected Triton format
            assert request_data["inputs"][0]["name"] == "text_input"
            assert request_data["inputs"][0]["shape"] == [1]
            assert request_data["inputs"][0]["datatype"] == "BYTES"
            assert request_data["inputs"][0]["data"] == [
                "0004900005025 0004900005026 0004900005027"
            ]
            assert request_data["inputs"][1]["shape"] == [1]
            assert request_data["inputs"][1]["datatype"] == "INT32"
            assert request_data["inputs"][1]["data"] == [20]

            # Verify response format matches expected completion format
            assert (
                response.choices[0].message.content
                == "0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027"
            )
            assert response.choices[0].finish_reason == "stop"
            assert response.choices[0].index == 0
            assert response.object == "chat.completion"

    except Exception as e:
        print("exception", e)
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")
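

# Hypothetical helper (not part of litellm) mirroring the KServe v2 /infer
# payload shape asserted in test_completion_triton_infer_api above. The
# "max_tokens" input name is an assumption: the test only asserts the second
# input's shape, datatype, and data, not its name.
def _build_infer_request_sketch(text: str, max_tokens: int) -> dict:
    return {
        "inputs": [
            {"name": "text_input", "shape": [1], "datatype": "BYTES", "data": [text]},
            {"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [max_tokens]},
        ]
    }


def test_build_infer_request_sketch():
    request_data = _build_infer_request_sketch("0004900005025", 20)
    assert request_data["inputs"][0]["datatype"] == "BYTES"
    assert request_data["inputs"][1]["datatype"] == "INT32"
    assert request_data["inputs"][1]["data"] == [20]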


@pytest.mark.asyncio
async def test_triton_embeddings():
    try:
        litellm.set_verbose = True
        response = await litellm.aembedding(
            model="triton/my-triton-model",
            api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
            input=["good morning from litellm"],
        )
        print(f"response: {response}")

        # the stubbed endpoint is set up to return this
        assert response.data[0]["embedding"] == [0.1, 0.2]
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
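

# A minimal synchronous counterpart to the async test above, assuming the same
# stubbed endpoint and return value; litellm.embedding mirrors aembedding's
# signature. This is a sketch, not a test from litellm's own suite.
def test_triton_embeddings_sync_sketch():
    try:
        response = litellm.embedding(
            model="triton/my-triton-model",
            api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
            input=["good morning from litellm"],
        )
        assert response.data[0]["embedding"] == [0.1, 0.2]
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")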