import json
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()
import io
from unittest.mock import AsyncMock, MagicMock, patch

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import pytest

import litellm
from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig


def test_split_embedding_by_shape_passes():
    try:
        data = [
            {
                "shape": [2, 3],
                "data": [1, 2, 3, 4, 5, 6],
            }
        ]
        split_output_data = TritonEmbeddingConfig.split_embedding_by_shape(
            data[0]["data"], data[0]["shape"]
        )
        assert split_output_data == [[1, 2, 3], [4, 5, 6]]
    except Exception as e:
        pytest.fail(f"An exception occurred: {e}")


def test_split_embedding_by_shape_fails_with_shape_value_error():
    data = [
        {
            "shape": [2],
            "data": [1, 2, 3, 4, 5, 6],
        }
    ]
    with pytest.raises(ValueError):
        TritonEmbeddingConfig.split_embedding_by_shape(
            data[0]["data"], data[0]["shape"]
        )


@pytest.mark.parametrize("stream", [True, False])
def test_completion_triton_generate_api(stream):
    try:
        mock_response = MagicMock()

        if stream:

            def mock_iter_lines():
                mock_output = "".join(
                    [
                        'data: {"model_name":"ensemble","model_version":"1","sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"'
                        + t
                        + '"}\n\n'
                        for t in ["I", " am", " an", " AI", " assistant"]
                    ]
                )
                for out in mock_output.split("\n"):
                    yield out

            mock_response.iter_lines = mock_iter_lines
        else:

            def return_val():
                return {
                    "text_output": "I am an AI assistant",
                }

            mock_response.json = return_val
        mock_response.status_code = 200

        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            return_value=mock_response,
        ) as mock_post:
            response = litellm.completion(
                model="triton/llama-3-8b-instruct",
                messages=[{"role": "user", "content": "who are u?"}],
                max_tokens=10,
                timeout=5,
                api_base="http://localhost:8000/generate",
                stream=stream,
            )

            # Verify the call was made
            mock_post.assert_called_once()

            # Get the arguments passed to the post request
            print("call args", mock_post.call_args)
            call_kwargs = mock_post.call_args.kwargs  # Access kwargs directly

            # Verify URL
            if stream:
                assert call_kwargs["url"] == "http://localhost:8000/generate_stream"
            else:
                assert call_kwargs["url"] == "http://localhost:8000/generate"

            # Parse the request data from the JSON string
            request_data = json.loads(call_kwargs["data"])

            # Verify request data
            assert request_data["text_input"] == "who are u?"
assert request_data["parameters"]["max_tokens"] == 10 # Verify response if stream: tokens = ["I", " am", " an", " AI", " assistant", None] idx = 0 for chunk in response: assert chunk.choices[0].delta.content == tokens[idx] idx += 1 assert idx == len(tokens) else: assert response.choices[0].message.content == "I am an AI assistant" except Exception as e: print("exception", e) import traceback traceback.print_exc() pytest.fail(f"Error occurred: {e}") def test_completion_triton_infer_api(): litellm.set_verbose = True try: mock_response = MagicMock() def return_val(): return { "model_name": "basketgpt", "model_version": "2", "outputs": [ { "name": "text_output", "datatype": "BYTES", "shape": [1], "data": [ "0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027" ], }, { "name": "debug_probs", "datatype": "FP32", "shape": [0], "data": [], }, { "name": "debug_tokens", "datatype": "BYTES", "shape": [0], "data": [], }, ], } mock_response.json = return_val mock_response.status_code = 200 with patch( "litellm.llms.custom_httpx.http_handler.HTTPHandler.post", return_value=mock_response, ) as mock_post: response = litellm.completion( model="triton/llama-3-8b-instruct", messages=[ { "role": "user", "content": "0004900005025 0004900005026 0004900005027", } ], api_base="http://localhost:8000/infer", ) print("litellm response", response.model_dump_json(indent=4)) # Verify the call was made mock_post.assert_called_once() # Get the arguments passed to the post request call_kwargs = mock_post.call_args.kwargs # Verify URL assert call_kwargs["url"] == "http://localhost:8000/infer" # Parse the request data from the JSON string request_data = json.loads(call_kwargs["data"]) # Verify request matches expected Triton format assert request_data["inputs"][0]["name"] == "text_input" assert request_data["inputs"][0]["shape"] == [1] assert request_data["inputs"][0]["datatype"] == "BYTES" assert request_data["inputs"][0]["data"] == [ "0004900005025 0004900005026 0004900005027" ] assert request_data["inputs"][1]["shape"] == [1] assert request_data["inputs"][1]["datatype"] == "INT32" assert request_data["inputs"][1]["data"] == [20] # Verify response format matches expected completion format assert ( response.choices[0].message.content == "0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027" ) assert response.choices[0].finish_reason == "stop" assert response.choices[0].index == 0 assert response.object == "chat.completion" except Exception as e: print("exception", e) traceback.print_exc() pytest.fail(f"Error occurred: {e}") @pytest.mark.asyncio async def test_triton_embeddings(): try: litellm.set_verbose = True response = await litellm.aembedding( model="triton/my-triton-model", api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings", input=["good morning from litellm"], ) print(f"response: {response}") # stubbed endpoint is setup to return this assert response.data[0]["embedding"] == [0.1, 0.2] except Exception as e: pytest.fail(f"Error occurred: {e}") def test_triton_generate_raw_request(): from litellm.utils import return_raw_request from litellm.types.utils import CallTypes try: kwargs = { "model": "triton/llama-3-8b-instruct", "messages": [{"role": "user", "content": "who are u?"}], "api_base": "http://localhost:8000/generate", } raw_request = return_raw_request(endpoint=CallTypes.completion, kwargs=kwargs) 
print("raw_request", raw_request) assert raw_request is not None assert "bad_words" not in json.dumps(raw_request["raw_request_body"]) assert "stop_words" not in json.dumps(raw_request["raw_request_body"]) except Exception as e: pytest.fail(f"Error occurred: {e}")