feat: enables batch embedding support for triton

parent 2bf1f06a0e
commit 35b733f14d

2 changed files with 52 additions and 8 deletions
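With this change, one embedding call can carry a whole batch: every input string is packed into a single Triton /infer request, and one embedding object comes back per input. A minimal usage sketch; the model name, the triton/ routing prefix, and the api_base URL below are illustrative assumptions, not values from this commit:

    import asyncio
    import litellm

    async def main():
        response = await litellm.aembedding(
            model="triton/my_embedding_model",  # hypothetical model name and routing prefix
            api_base="http://localhost:8000/v2/models/my_embedding_model/infer",  # hypothetical endpoint
            input=["first sentence", "second sentence"],  # a batch of two inputs
        )
        # Expect one embedding object per input, indexed 0 and 1.
        print(len(response.data))

    asyncio.run(main())

Only the async path is supported here; the synchronous route raises, as the diff below shows.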
litellm/llms/triton.py (+23 −8)

@@ -52,17 +52,25 @@ class TritonChatCompletion(BaseLLM):
         logging_obj.post_call(original_response=_text_response)
 
         _json_response = response.json()
+        _embedding_output = []
 
         _outputs = _json_response["outputs"]
-        _output_data = _outputs[0]["data"]
-        _embedding_output = {
-            "object": "embedding",
-            "index": 0,
-            "embedding": _output_data,
-        }
+        for output in _outputs:
+            _shape = output["shape"]
+            _data = output["data"]
+            _split_output_data = self.split_embedding_by_shape(_data, _shape)
+
+            for idx, embedding in enumerate(_split_output_data):
+                _embedding_output.append(
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": embedding,
+                    }
+                )
 
         model_response.model = _json_response.get("model_name", "None")
-        model_response.data = [_embedding_output]
+        model_response.data = _embedding_output
 
         return model_response
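For context: a Triton v2 /infer response carries one flat data array per output together with its shape, and the new loop slices each output back into per-input vectors, emitting one OpenAI-style embedding object per input instead of the old hardcoded single entry. A toy illustration (field values are made up, and the output name depends on the model config):

    # Abbreviated Triton v2 response for a batch of 2 inputs from a
    # 3-dimensional embedding model (values are made up):
    _json_response = {
        "model_name": "my_embedding_model",
        "outputs": [
            {
                "name": "embedding",
                "shape": [2, 3],
                "datatype": "FP32",
                "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            }
        ],
    }
    # The loop above reduces this to two embedding objects:
    #   {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}
    #   {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]}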
@@ -83,7 +91,7 @@ class TritonChatCompletion(BaseLLM):
             "inputs": [
                 {
                     "name": "input_text",
-                    "shape": [1],
+                    "shape": [len(input)],
                     "datatype": "BYTES",
                     "data": input,
                 }
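On the request side, the only change is that the input tensor's shape now advertises the real batch size: for a BYTES input, shape [len(input)] tells Triton how many strings are packed into data. A payload sketch for a batch of two (the tensor name input_text comes from the diff; the surrounding request construction is abbreviated):

    input = ["first sentence", "second sentence"]
    data = {
        "inputs": [
            {
                "name": "input_text",
                "shape": [len(input)],  # [2] for this batch
                "datatype": "BYTES",
                "data": input,
            }
        ]
    }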
@@ -116,3 +124,10 @@ class TritonChatCompletion(BaseLLM):
             raise Exception(
                 "Only async embedding supported for triton, please use litellm.aembedding() for now"
             )
+
+    @staticmethod
+    def split_embedding_by_shape(data: list[float], shape: list[int]) -> list[list[float]]:
+        if len(shape) != 2:
+            raise ValueError("Shape must be of length 2.")
+        embedding_size = shape[1]
+        return [data[i * embedding_size: (i + 1) * embedding_size] for i in range(shape[0])]
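split_embedding_by_shape itself is pure slicing arithmetic over the flat data array, so it can be unit-tested without a server, which the new test file below does.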
litellm/tests/test_triton.py (new file, +29)
@@ -0,0 +1,29 @@
+import pytest
+from litellm.llms.triton import TritonChatCompletion
+
+
+def test_split_embedding_by_shape_passes():
+    try:
+        triton = TritonChatCompletion()
+        data = [
+            {
+                "shape": [2, 3],
+                "data": [1, 2, 3, 4, 5, 6],
+            }
+        ]
+        split_output_data = triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
+        assert split_output_data == [[1, 2, 3], [4, 5, 6]]
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {e}")
+
+
+def test_split_embedding_by_shape_fails_with_shape_value_error():
+    triton = TritonChatCompletion()
+    data = [
+        {
+            "shape": [2],
+            "data": [1, 2, 3, 4, 5, 6],
+        }
+    ]
+    with pytest.raises(ValueError):
+        triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
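Both tests exercise only the pure helper, so they run without a live Triton server; a plain pytest litellm/tests/test_triton.py invocation covers the happy path and the shape-validation error.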