feat: enable batch embedding support for triton

This commit is contained in:
davidschuler-8451 2024-07-16 13:31:59 -04:00
parent 2bf1f06a0e
commit 35b733f14d
2 changed files with 52 additions and 8 deletions

View file

@ -52,17 +52,25 @@ class TritonChatCompletion(BaseLLM):
logging_obj.post_call(original_response=_text_response)
_json_response = response.json()
_embedding_output = []
_outputs = _json_response["outputs"]
_output_data = _outputs[0]["data"]
_embedding_output = {
"object": "embedding",
"index": 0,
"embedding": _output_data,
}
for output in _outputs:
_shape = output["shape"]
_data = output["data"]
_split_output_data = self.split_embedding_by_shape(_data, _shape)
for idx, embedding in enumerate(_split_output_data):
_embedding_output.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding,
}
)
model_response.model = _json_response.get("model_name", "None")
model_response.data = [_embedding_output]
model_response.data = _embedding_output
return model_response
@ -83,7 +91,7 @@ class TritonChatCompletion(BaseLLM):
"inputs": [
{
"name": "input_text",
"shape": [1],
"shape": [len(input)],
"datatype": "BYTES",
"data": input,
}
@ -116,3 +124,10 @@ class TritonChatCompletion(BaseLLM):
raise Exception(
"Only async embedding supported for triton, please use litellm.aembedding() for now"
)
@staticmethod
def split_embedding_by_shape(data: list[float], shape: list[int]) -> list[list[float]]:
if len(shape) != 2:
raise ValueError("Shape must be of length 2.")
embedding_size = shape[1]
return [data[i * embedding_size: (i + 1) * embedding_size] for i in range(shape[0])]

View file

@ -0,0 +1,29 @@
import pytest
from litellm.llms.triton import TritonChatCompletion
def test_split_embedding_by_shape_passes():
    """Happy path: a flat 2x3 buffer splits into two 3-element embeddings."""
    triton = TritonChatCompletion()
    data = [
        {
            "shape": [2, 3],
            "data": [1, 2, 3, 4, 5, 6],
        }
    ]
    # No try/except wrapper: letting an unexpected exception propagate gives
    # pytest the full traceback, which is strictly more informative than
    # pytest.fail("An exception occurred: ...").
    split_output_data = triton.split_embedding_by_shape(
        data[0]["data"], data[0]["shape"]
    )
    assert split_output_data == [[1, 2, 3], [4, 5, 6]]
def test_split_embedding_by_shape_fails_with_shape_value_error():
    """A shape descriptor that is not 2-D must raise ValueError."""
    triton = TritonChatCompletion()
    bad_output = {
        "shape": [2],
        "data": [1, 2, 3, 4, 5, 6],
    }
    with pytest.raises(ValueError):
        triton.split_embedding_by_shape(bad_output["data"], bad_output["shape"])