feat: enables batch embedding support for triton
commit 35b733f14d
parent 2bf1f06a0e

2 changed files with 52 additions and 8 deletions
litellm/llms/triton.py

@@ -52,17 +52,25 @@ class TritonChatCompletion(BaseLLM):
         logging_obj.post_call(original_response=_text_response)
 
         _json_response = response.json()
 
+        _embedding_output = []
+
         _outputs = _json_response["outputs"]
-        _output_data = _outputs[0]["data"]
-        _embedding_output = {
-            "object": "embedding",
-            "index": 0,
-            "embedding": _output_data,
-        }
+        for output in _outputs:
+            _shape = output["shape"]
+            _data = output["data"]
+            _split_output_data = self.split_embedding_by_shape(_data, _shape)
+
+            for idx, embedding in enumerate(_split_output_data):
+                _embedding_output.append(
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": embedding,
+                    }
+                )
 
         model_response.model = _json_response.get("model_name", "None")
-        model_response.data = [_embedding_output]
+        model_response.data = _embedding_output
 
         return model_response
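For context, a minimal sketch of what the new loop does to a batched Triton response. The payload below is illustrative (made-up model name and values, not from a real server), and the inline slicing mirrors what split_embedding_by_shape does:

    # Illustrative Triton /infer response for a batch of two inputs;
    # model name and values are made up for this sketch.
    _json_response = {
        "model_name": "my-embedding-model",
        "outputs": [{"shape": [2, 3], "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]}],
    }

    _embedding_output = []
    for output in _json_response["outputs"]:
        _shape = output["shape"]
        _data = output["data"]
        # Same slicing as split_embedding_by_shape: shape[0] rows of shape[1] floats.
        _split = [_data[i * _shape[1]: (i + 1) * _shape[1]] for i in range(_shape[0])]
        for idx, embedding in enumerate(_split):
            _embedding_output.append(
                {"object": "embedding", "index": idx, "embedding": embedding}
            )

    # One embedding dict per input row, indexed 0..n-1.
    assert _embedding_output[1]["embedding"] == [0.4, 0.5, 0.6]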
@@ -83,7 +91,7 @@ class TritonChatCompletion(BaseLLM):
             "inputs": [
                 {
                     "name": "input_text",
-                    "shape": [1],
+                    "shape": [len(input)],
                     "datatype": "BYTES",
                     "data": input,
                 }
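With the hardcoded shape of [1] replaced by len(input), the request body now advertises the true batch size. An illustrative payload for a batch of two strings (the input values are made up; the field names come from the diff above):

    input = ["hello world", "good morning"]
    request_body = {
        "inputs": [
            {
                "name": "input_text",
                "shape": [len(input)],  # batch size 2; previously hardcoded to [1]
                "datatype": "BYTES",
                "data": input,
            }
        ]
    }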
@@ -116,3 +124,10 @@ class TritonChatCompletion(BaseLLM):
             raise Exception(
                 "Only async embedding supported for triton, please use litellm.aembedding() for now"
             )
+
+    @staticmethod
+    def split_embedding_by_shape(data: list[float], shape: list[int]) -> list[list[float]]:
+        if len(shape) != 2:
+            raise ValueError("Shape must be of length 2.")
+        embedding_size = shape[1]
+        return [data[i * embedding_size: (i + 1) * embedding_size] for i in range(shape[0])]
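As a quick sanity check of the slicing arithmetic in split_embedding_by_shape: a shape of [3, 2] means three embeddings of two dimensions each, so the flat data list is cut into consecutive runs of shape[1] values (values below are illustrative):

    data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
    shape = [3, 2]  # 3 embeddings, 2 dimensions each
    embedding_size = shape[1]
    rows = [data[i * embedding_size: (i + 1) * embedding_size] for i in range(shape[0])]
    assert rows == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]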
litellm/tests/test_triton.py (new file, 29 additions)

@@ -0,0 +1,29 @@
+import pytest
+from litellm.llms.triton import TritonChatCompletion
+
+
+def test_split_embedding_by_shape_passes():
+    try:
+        triton = TritonChatCompletion()
+        data = [
+            {
+                "shape": [2, 3],
+                "data": [1, 2, 3, 4, 5, 6],
+            }
+        ]
+        split_output_data = triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
+        assert split_output_data == [[1, 2, 3], [4, 5, 6]]
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {e}")
+
+
+def test_split_embedding_by_shape_fails_with_shape_value_error():
+    triton = TritonChatCompletion()
+    data = [
+        {
+            "shape": [2],
+            "data": [1, 2, 3, 4, 5, 6],
+        }
+    ]
+    with pytest.raises(ValueError):
+        triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
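Finally, a hedged sketch of how the batch path might be exercised end to end. The model name and api_base below are placeholders, not endpoints taken from this commit; only the litellm.aembedding() entry point is confirmed by the error message in the diff above:

    import asyncio

    import litellm


    async def main():
        # Placeholders: swap in your own Triton model name and inference endpoint.
        response = await litellm.aembedding(
            model="triton/my-embedding-model",
            api_base="http://localhost:8000/v2/models/my-embedding-model/infer",
            input=["first sentence", "second sentence"],  # a batch of two inputs
        )
        # After this commit, response.data holds one embedding dict per input.
        for item in response.data:
            print(item["index"], len(item["embedding"]))


    asyncio.run(main())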