diff --git a/litellm/llms/triton.py b/litellm/llms/triton.py index d647c9c43..6d3bebcb3 100644 --- a/litellm/llms/triton.py +++ b/litellm/llms/triton.py @@ -52,17 +52,25 @@ class TritonChatCompletion(BaseLLM): logging_obj.post_call(original_response=_text_response) _json_response = response.json() + _embedding_output = [] _outputs = _json_response["outputs"] - _output_data = _outputs[0]["data"] - _embedding_output = { - "object": "embedding", - "index": 0, - "embedding": _output_data, - } + for output in _outputs: + _shape = output["shape"] + _data = output["data"] + _split_output_data = self.split_embedding_by_shape(_data, _shape) + + for idx, embedding in enumerate(_split_output_data): + _embedding_output.append( + { + "object": "embedding", + "index": idx, + "embedding": embedding, + } + ) model_response.model = _json_response.get("model_name", "None") - model_response.data = [_embedding_output] + model_response.data = _embedding_output return model_response @@ -83,7 +91,7 @@ class TritonChatCompletion(BaseLLM): "inputs": [ { "name": "input_text", - "shape": [1], + "shape": [len(input)], "datatype": "BYTES", "data": input, } @@ -116,3 +124,10 @@ class TritonChatCompletion(BaseLLM): raise Exception( "Only async embedding supported for triton, please use litellm.aembedding() for now" ) + + @staticmethod + def split_embedding_by_shape(data: list[float], shape: list[int]) -> list[list[float]]: + if len(shape) != 2: + raise ValueError("Shape must be of length 2.") + embedding_size = shape[1] + return [data[i * embedding_size: (i + 1) * embedding_size] for i in range(shape[0])] diff --git a/litellm/tests/test_triton.py b/litellm/tests/test_triton.py new file mode 100644 index 000000000..cb1e1af28 --- /dev/null +++ b/litellm/tests/test_triton.py @@ -0,0 +1,29 @@ +import pytest +from litellm.llms.triton import TritonChatCompletion + + +def test_split_embedding_by_shape_passes(): + try: + triton = TritonChatCompletion() + data = [ + { + "shape": [2, 3], + "data": [1, 2, 
3, 4, 5, 6], + } + ] + split_output_data = triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"]) + assert split_output_data == [[1, 2, 3], [4, 5, 6]] + except Exception as e: + pytest.fail(f"An exception occurred: {e}") + + +def test_split_embedding_by_shape_fails_with_shape_value_error(): + triton = TritonChatCompletion() + data = [ + { + "shape": [2], + "data": [1, 2, 3, 4, 5, 6], + } + ] + with pytest.raises(ValueError): + triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])