[Bug fix]: Triton /infer handler incompatible with batch responses (#7337)

* migrate triton to base llm http handler

* clean up triton handler.py

* use transform functions for triton

* add TritonConfig

* get openai params for triton

* use triton embedding config

* test_completion_triton_generate_api

* test_completion_triton_infer_api

* fix TritonConfig doc string

* use TritonResponseIterator

* fix triton embeddings

* docs triton chat usage
Ishaan Jaff 2024-12-20 20:59:40 -08:00 committed by GitHub
parent 70a9ea99f2
commit 6107f9f3f3
11 changed files with 814 additions and 450 deletions
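For context on the headline fix: Triton `/infer` can return `outputs[0]["data"]` as a list of strings (a batch response), which the old handler passed straight through as the message content. A minimal sketch of the new behavior, following the `TritonInferConfig.transform_response` logic in this diff (the helper name here is illustrative):

```python
def _join_infer_output(raw_response_json: dict) -> str:
    # /infer may return a batch: outputs[0]["data"] is then a list of strings.
    # The new transform joins it into one string for choices[0].message.content;
    # a plain string is passed through unchanged.
    data = raw_response_json["outputs"][0]["data"]
    return "".join(data) if isinstance(data, list) else data
```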


@ -5,14 +5,190 @@ import TabItem from '@theme/TabItem';
LiteLLM supports Embedding Models on Triton Inference Servers
| Property | Details |
|-------|-------|
| Description | NVIDIA Triton Inference Server |
| Provider Route on LiteLLM | `triton/` |
| Supported Operations | `/chat/completions`, `/completions`, `/embeddings` |
| Supported Triton endpoints | `/infer`, `/generate`, `/embeddings` |
| Link to Provider Doc | [Triton Inference Server ↗](https://github.com/triton-inference-server/server) |
## Triton `/generate` - Chat Completion
<Tabs>
<TabItem value="sdk" label="SDK">
Use the `triton/` prefix to route to the Triton server.
```python
from litellm import completion
response = completion(
model="triton/llama-3-8b-instruct",
messages=[{"role": "user", "content": "who are u?"}],
max_tokens=10,
api_base="http://localhost:8000/generate",
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: my-triton-model
litellm_params:
model: triton/<your-triton-model>
api_base: https://your-triton-api-base/triton/generate
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
from openai import OpenAI
# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
response = client.chat.completions.create(
model="my-triton-model",
messages=[{"role": "user", "content": "who are u?"}],
max_tokens=10,
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
The `Authorization` header is optional, and only required if you're using the LiteLLM proxy with Virtual Keys.
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data ' {
"model": "my-triton-model",
"messages": [{"role": "user", "content": "who are u?"}]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Triton `/infer` - Chat Completion
<Tabs>
<TabItem value="sdk" label="SDK">
Use the `triton/` prefix to route to the Triton server.
```python
from litellm import completion
response = completion(
model="triton/llama-3-8b-instruct",
messages=[{"role": "user", "content": "who are u?"}],
max_tokens=10,
api_base="http://localhost:8000/infer",
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: my-triton-model
litellm_params:
model: triton/<your-triton-model>
api_base: https://your-triton-api-base/triton/infer
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
from openai import OpenAI
# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
response = client.chat.completions.create(
model="my-triton-model",
messages=[{"role": "user", "content": "who are u?"}],
max_tokens=10,
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
The `Authorization` header is optional, and only required if you're using the LiteLLM proxy with Virtual Keys.
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data ' {
"model": "my-triton-model",
"messages": [{"role": "user", "content": "who are u?"}]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Triton `/embeddings` - Embedding
<Tabs>
<TabItem value="sdk" label="SDK">
Use the `triton/` prefix to route to the Triton server.
```python
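# Sketch of embedding usage, mirroring the test_triton_embeddings test added
# later in this PR (the model name and api_base below are placeholders):
import asyncio
import litellm

async def main():
    response = await litellm.aembedding(
        model="triton/my-triton-model",
        api_base="https://your-triton-api-base/triton/embeddings",
        input=["good morning from litellm"],
    )
    print(response)

asyncio.run(main())
```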


@ -1019,6 +1019,9 @@ from .llms.anthropic.experimental_pass_through.transformation import (
from .llms.groq.stt.transformation import GroqSTTConfig
from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.triton.completion.transformation import TritonConfig
from .llms.triton.completion.transformation import TritonGenerateConfig
from .llms.triton.completion.transformation import TritonInferConfig
from .llms.triton.embedding.transformation import TritonEmbeddingConfig
from .llms.databricks.chat.transformation import DatabricksConfig
from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase.chat.transformation import PredibaseConfig


@ -183,4 +183,11 @@ def get_supported_openai_params( # noqa: PLR0915
return litellm.PredibaseConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "voyage":
return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "triton":
if request_type == "embeddings":
return litellm.TritonEmbeddingConfig().get_supported_openai_params(
model=model
)
else:
return litellm.TritonConfig().get_supported_openai_params(model=model)
return None
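With this routing in place, a quick sketch of how supported params resolve for each request type (expected values follow the `TritonConfig` and `TritonEmbeddingConfig` definitions in this diff):

```python
import litellm

# Chat/completion requests route through TritonConfig
print(
    litellm.get_supported_openai_params(
        model="llama-3-8b-instruct", custom_llm_provider="triton"
    )
)  # expected: ["max_tokens", "max_completion_tokens"]

# Embedding requests route through TritonEmbeddingConfig
print(
    litellm.get_supported_openai_params(
        model="my-triton-model",
        custom_llm_provider="triton",
        request_type="embeddings",
    )
)  # expected: []
```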


@ -1,327 +1,5 @@
"""
Triton Completion - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""
import json
from typing import Any, List, Optional, Union
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.utils import Choices, EmbeddingResponse, Message, ModelResponse
from ...base import BaseLLM
from ..common_utils import TritonError
class TritonChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
async def aembedding(
self,
data: dict,
model_response: litellm.utils.EmbeddingResponse,
api_base: str,
logging_obj: Any,
api_key: Optional[str] = None,
) -> EmbeddingResponse:
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
)
response = await async_handler.post(url=api_base, data=json.dumps(data))
if response.status_code != 200:
raise TritonError(status_code=response.status_code, message=response.text)
_text_response = response.text
logging_obj.post_call(original_response=_text_response)
_json_response = response.json()
_embedding_output = []
_outputs = _json_response["outputs"]
for output in _outputs:
_shape = output["shape"]
_data = output["data"]
_split_output_data = self.split_embedding_by_shape(_data, _shape)
for idx, embedding in enumerate(_split_output_data):
_embedding_output.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding,
}
)
model_response.model = _json_response.get("model_name", "None")
model_response.data = _embedding_output
return model_response
async def embedding(
self,
model: str,
input: List[str],
timeout: float,
api_base: str,
model_response: litellm.utils.EmbeddingResponse,
logging_obj: Any,
optional_params: dict,
api_key: Optional[str] = None,
client=None,
aembedding: bool = False,
) -> EmbeddingResponse:
data_for_triton = {
"inputs": [
{
"name": "input_text",
"shape": [len(input)],
"datatype": "BYTES",
"data": input,
}
]
}
curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
logging_obj.pre_call(
input="",
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": curl_string,
},
)
if aembedding:
response = await self.aembedding( # type: ignore
data=data_for_triton,
model_response=model_response,
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
)
return response
else:
raise Exception(
"Only async embedding supported for triton, please use litellm.aembedding() for now"
)
def completion(
self,
model: str,
messages: List,
timeout: float,
api_base: str,
logging_obj: Any,
optional_params: dict,
litellm_params: dict,
model_response: ModelResponse,
api_key: Optional[str] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
stream: Optional[bool] = False,
acompletion: bool = False,
headers: Optional[dict] = None,
) -> ModelResponse:
type_of_model = ""
optional_params.pop("stream", False)
if api_base.endswith("generate"): ### This is a trtllm model
data_for_triton = litellm.TritonConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers or {},
)
type_of_model = "trtllm"
elif api_base.endswith(
"infer"
): ### This is an infer model with a custom model on triton
text_input = messages[0]["content"]
data_for_triton = {
"inputs": [
{
"name": "text_input",
"shape": [1],
"datatype": "BYTES",
"data": [text_input],
}
]
}
for k, v in optional_params.items():
if not (k == "stream" or k == "max_retries"):
datatype = "INT32" if isinstance(v, int) else "BYTES"
datatype = "FP32" if isinstance(v, float) else datatype
data_for_triton["inputs"].append(
{"name": k, "shape": [1], "datatype": datatype, "data": [v]}
)
if "max_tokens" not in optional_params:
data_for_triton["inputs"].append(
{
"name": "max_tokens",
"shape": [1],
"datatype": "INT32",
"data": [20],
}
)
type_of_model = "infer"
else: ## Unknown model type passthrough
data_for_triton = {
"inputs": [
{
"name": "text_input",
"shape": [1],
"datatype": "BYTES",
"data": [messages[0]["content"]],
}
]
}
if logging_obj:
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": optional_params,
"api_base": api_base,
"http_client": client,
},
)
headers = litellm.TritonConfig().validate_environment(
headers=headers or {},
model=model,
messages=messages,
optional_params=optional_params,
api_key=api_key,
)
json_data_for_triton: str = json.dumps(data_for_triton)
if acompletion:
return self.acompletion( # type: ignore
model,
json_data_for_triton,
headers=headers,
logging_obj=logging_obj,
api_base=api_base,
stream=stream,
model_response=model_response,
type_of_model=type_of_model,
)
if client is None or not isinstance(client, HTTPHandler):
handler = _get_httpx_client()
else:
handler = client
if stream:
return self._handle_stream( # type: ignore
handler, api_base, json_data_for_triton, model, logging_obj
)
else:
response = handler.post(
url=api_base, data=json_data_for_triton, headers=headers
)
return self._handle_response(
response, model_response, logging_obj, type_of_model=type_of_model
)
async def acompletion(
self,
model: str,
data_for_triton,
api_base,
stream,
logging_obj,
headers,
model_response,
type_of_model,
) -> ModelResponse:
handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
)
if stream:
return self._ahandle_stream( # type: ignore
handler, api_base, data_for_triton, model, logging_obj
)
else:
response = await handler.post(
url=api_base, data=data_for_triton, headers=headers
)
return self._handle_response(
response, model_response, logging_obj, type_of_model=type_of_model
)
def _handle_stream(self, handler, api_base, data_for_triton, model, logging_obj):
response = handler.post(
url=api_base + "_stream", data=data_for_triton, stream=True
)
streamwrapper = litellm.CustomStreamWrapper(
response.iter_lines(),
model=model,
custom_llm_provider="triton",
logging_obj=logging_obj,
)
for chunk in streamwrapper:
yield (chunk)
async def _ahandle_stream(
self, handler, api_base, data_for_triton, model, logging_obj
):
response = await handler.post(
url=api_base + "_stream", data=data_for_triton, stream=True
)
streamwrapper = litellm.CustomStreamWrapper(
response.aiter_lines(),
model=model,
custom_llm_provider="triton",
logging_obj=logging_obj,
)
async for chunk in streamwrapper:
yield (chunk)
def _handle_response(self, response, model_response, logging_obj, type_of_model):
if logging_obj:
logging_obj.post_call(original_response=response)
if response.status_code != 200:
raise TritonError(status_code=response.status_code, message=response.text)
_json_response = response.json()
model_response.model = _json_response.get("model_name", "None")
if type_of_model == "trtllm":
model_response.choices = [
Choices(index=0, message=Message(content=_json_response["text_output"]))
]
elif type_of_model == "infer":
model_response.choices = [
Choices(
index=0,
message=Message(content=_json_response["outputs"][0]["data"]),
)
]
else:
model_response.choices = [
Choices(index=0, message=Message(content=_json_response["outputs"]))
]
return model_response
@staticmethod
def split_embedding_by_shape(
data: List[float], shape: List[int]
) -> List[List[float]]:
if len(shape) != 2:
raise ValueError("Shape must be of length 2.")
embedding_size = shape[1]
return [
data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
]


@ -2,44 +2,37 @@
Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint.
"""
import json
from typing import Any, Dict, List, Literal, Optional, Union
from httpx import Headers, Response
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import (
BaseConfig,
BaseLLMException,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
Choices,
GenericStreamingChunk,
Message,
ModelResponse,
)
from ..common_utils import TritonError
class TritonConfig(BaseConfig):
"""
Base class for Triton configurations.
Handles routing between /infer and /generate triton completion llms
"""
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, Headers]
@ -48,6 +41,16 @@ class TritonConfig(BaseConfig):
status_code=status_code, message=error_message, headers=headers
)
def validate_environment(
self,
headers: Dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
) -> Dict:
return {"Content-Type": "application/json"}
def get_supported_openai_params(self, model: str) -> List:
return ["max_tokens", "max_completion_tokens"]
@ -77,16 +80,236 @@ class TritonConfig(BaseConfig):
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
api_base = litellm_params.get("api_base", "")
llm_type = self._get_triton_llm_type(api_base)
if llm_type == "generate":
return TritonGenerateConfig().transform_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
api_key=api_key,
json_mode=json_mode,
) )
elif llm_type == "infer":
return TritonInferConfig().transform_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
api_key=api_key,
json_mode=json_mode,
)
return model_response
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
api_base = litellm_params.get("api_base", "")
llm_type = self._get_triton_llm_type(api_base)
if llm_type == "generate":
return TritonGenerateConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
elif llm_type == "infer":
return TritonInferConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
return {}
def _get_triton_llm_type(self, api_base: str) -> Literal["generate", "infer"]:
if api_base.endswith("/generate"):
return "generate"
elif api_base.endswith("/infer"):
return "infer"
else:
raise ValueError(f"Invalid Triton API base: {api_base}")
class TritonGenerateConfig(TritonConfig):
"""
Transformations for triton /generate endpoint (This is a trtllm model)
"""
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
inference_params = optional_params.copy()
stream = inference_params.pop("stream", False)
data_for_triton: Dict[str, Any] = {
"text_input": prompt_factory(model=model, messages=messages),
"parameters": {
"max_tokens": int(optional_params.get("max_tokens", 2000)),
"bad_words": [""],
"stop_words": [""],
},
"stream": bool(stream),
}
data_for_triton["parameters"].update(inference_params)
return data_for_triton
def transform_response(
self,
model: str,
raw_response: Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
raw_response_json = raw_response.json()
except Exception:
raise TritonError(
message=raw_response.text, status_code=raw_response.status_code
)
model_response.choices = [
Choices(index=0, message=Message(content=raw_response_json["text_output"]))
]
return model_response
class TritonInferConfig(TritonGenerateConfig):
"""
Transformations for triton /infer endpoint (this is an infer model with a custom model on triton)
"""
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
text_input = messages[0].get("content", "")
data_for_triton = {
"inputs": [
{
"name": "text_input",
"shape": [1],
"datatype": "BYTES",
"data": [text_input],
}
]
}
for k, v in optional_params.items():
if not (k == "stream" or k == "max_retries"):
datatype = "INT32" if isinstance(v, int) else "BYTES"
datatype = "FP32" if isinstance(v, float) else datatype
data_for_triton["inputs"].append(
{"name": k, "shape": [1], "datatype": datatype, "data": [v]}
)
if "max_tokens" not in optional_params:
data_for_triton["inputs"].append(
{
"name": "max_tokens",
"shape": [1],
"datatype": "INT32",
"data": [20],
}
)
return data_for_triton
def transform_response(
self,
model: str,
raw_response: Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
raw_response_json = raw_response.json()
except Exception:
raise TritonError(
message=raw_response.text, status_code=raw_response.status_code
)
_triton_response_data = raw_response_json["outputs"][0]["data"]
triton_response_data: Optional[str] = None
if isinstance(_triton_response_data, list):
triton_response_data = "".join(_triton_response_data)
else:
triton_response_data = _triton_response_data
model_response.choices = [
Choices(
index=0,
message=Message(content=triton_response_data),
)
]
return model_response
class TritonResponseIterator(BaseModelResponseIterator):
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields = None
index = int(chunk.get("index", 0))
# set values
text = chunk.get("text_output", "")
finish_reason = chunk.get("stop_reason", "")
is_finished = chunk.get("is_finished", False)
return GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=index,
provider_specific_fields=provider_specific_fields,
)
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")


@ -0,0 +1,121 @@
from typing import List, Optional, Union
import httpx
from litellm.llms.base_llm.chat.transformation import AllMessageValues, BaseLLMException
from litellm.llms.base_llm.embedding.transformation import (
BaseEmbeddingConfig,
LiteLLMLoggingObj,
)
from litellm.types.utils import EmbeddingResponse
from ..common_utils import TritonError
class TritonEmbeddingConfig(BaseEmbeddingConfig):
"""
Transformations for triton /embeddings endpoint
"""
def __init__(self) -> None:
pass
def get_supported_openai_params(self, model: str) -> list:
return []
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
Map OpenAI params to Triton Embedding params
"""
return optional_params
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
return {}
def transform_embedding_request(
self,
model: str,
input: Union[str, List[str], List[float], List[List[float]]],
optional_params: dict,
headers: dict,
) -> dict:
return {
"inputs": [
{
"name": "input_text",
"shape": [len(input)],
"datatype": "BYTES",
"data": input,
}
]
}
def transform_embedding_response(
self,
model: str,
raw_response: httpx.Response,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
request_data: dict = {},
optional_params: dict = {},
litellm_params: dict = {},
) -> EmbeddingResponse:
try:
raw_response_json = raw_response.json()
except Exception:
raise TritonError(
message=raw_response.text, status_code=raw_response.status_code
)
_embedding_output = []
_outputs = raw_response_json["outputs"]
for output in _outputs:
_shape = output["shape"]
_data = output["data"]
_split_output_data = self.split_embedding_by_shape(_data, _shape)
for idx, embedding in enumerate(_split_output_data):
_embedding_output.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding,
}
)
model_response.model = raw_response_json.get("model_name", "None")
model_response.data = _embedding_output
return model_response
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return TritonError(
message=error_message, status_code=status_code, headers=headers
)
@staticmethod
def split_embedding_by_shape(
data: List[float], shape: List[int]
) -> List[List[float]]:
if len(shape) != 2:
raise ValueError("Shape must be of length 2.")
embedding_size = shape[1]
return [
data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
]


@ -128,7 +128,6 @@ from .llms.replicate.chat.handler import completion as replicate_chat_completion
from .llms.sagemaker.chat.handler import SagemakerChatHandler
from .llms.sagemaker.completion.handler import SagemakerLLM
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton.completion.handler import TritonChatCompletion
from .llms.vertex_ai import vertex_ai_non_gemini
from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
@ -194,7 +193,6 @@ azure_audio_transcriptions = AzureAudioTranscription()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
codestral_text_completions = CodestralTextCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
bedrock_embedding = BedrockEmbedding()
@ -2711,24 +2709,22 @@ def completion( # type: ignore # noqa: PLR0915
elif custom_llm_provider == "triton":
api_base = litellm.api_base or api_base
response = base_llm_http_handler.completion(
model=model,
stream=stream,
messages=messages,
acompletion=acompletion,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
timeout=timeout,
headers=headers,
encoding=encoding,
api_key=api_key,
logging_obj=logging,
)
## RESPONSE OBJECT
response = model_response
return response
elif custom_llm_provider == "cloudflare":
api_key = (
api_key
@ -3477,9 +3473,10 @@ def embedding( # noqa: PLR0915
raise ValueError(
"api_base is required for triton. Please pass `api_base`"
)
response = base_llm_http_handler.embedding(
model=model,
input=input,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
api_key=api_key,
logging_obj=logging,


@ -2249,10 +2249,19 @@ def get_optional_params_embeddings( # noqa: PLR0915
message="Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
)
elif custom_llm_provider == "triton":
supported_params = get_supported_openai_params(
model=model,
custom_llm_provider=custom_llm_provider,
request_type="embeddings",
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.TritonEmbeddingConfig().map_openai_params(
non_default_params=non_default_params,
optional_params={},
model=model,
drop_params=drop_params if drop_params is not None else False,
)
final_params = {**optional_params, **kwargs}
return final_params
elif custom_llm_provider == "databricks":
supported_params = get_supported_openai_params(
@ -2812,6 +2821,17 @@ def get_optional_params( # noqa: PLR0915
else False
),
)
elif custom_llm_provider == "triton":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.TritonConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=drop_params if drop_params is not None else False,
)
elif custom_llm_provider == "maritalk":
## check if unsupported param passed in
@ -6222,6 +6242,8 @@ class ProviderConfigManager:
) -> BaseEmbeddingConfig:
if litellm.LlmProviders.VOYAGE == provider:
return litellm.VoyageEmbeddingConfig()
elif litellm.LlmProviders.TRITON == provider:
return litellm.TritonEmbeddingConfig()
raise ValueError(f"Provider {provider} does not support embedding config")
@staticmethod


@ -0,0 +1,210 @@
import json
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import io
from unittest.mock import AsyncMock, MagicMock, patch
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig
def test_split_embedding_by_shape_passes():
try:
data = [
{
"shape": [2, 3],
"data": [1, 2, 3, 4, 5, 6],
}
]
split_output_data = TritonEmbeddingConfig.split_embedding_by_shape(
data[0]["data"], data[0]["shape"]
)
assert split_output_data == [[1, 2, 3], [4, 5, 6]]
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
def test_split_embedding_by_shape_fails_with_shape_value_error():
data = [
{
"shape": [2],
"data": [1, 2, 3, 4, 5, 6],
}
]
with pytest.raises(ValueError):
TritonEmbeddingConfig.split_embedding_by_shape(
data[0]["data"], data[0]["shape"]
)
def test_completion_triton_generate_api():
try:
mock_response = MagicMock()
def return_val():
return {
"text_output": "I am an AI assistant",
}
mock_response.json = return_val
mock_response.status_code = 200
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
return_value=mock_response,
) as mock_post:
response = litellm.completion(
model="triton/llama-3-8b-instruct",
messages=[{"role": "user", "content": "who are u?"}],
max_tokens=10,
timeout=5,
api_base="http://localhost:8000/generate",
)
# Verify the call was made
mock_post.assert_called_once()
# Get the arguments passed to the post request
print("call args", mock_post.call_args)
call_kwargs = mock_post.call_args.kwargs # Access kwargs directly
# Verify URL
assert call_kwargs["url"] == "http://localhost:8000/generate"
# Parse the request data from the JSON string
request_data = json.loads(call_kwargs["data"])
# Verify request data
assert request_data["text_input"] == "who are u?"
assert request_data["parameters"]["max_tokens"] == 10
# Verify response
assert response.choices[0].message.content == "I am an AI assistant"
except Exception as e:
print("exception", e)
import traceback
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
def test_completion_triton_infer_api():
litellm.set_verbose = True
try:
mock_response = MagicMock()
def return_val():
return {
"model_name": "basketgpt",
"model_version": "2",
"outputs": [
{
"name": "text_output",
"datatype": "BYTES",
"shape": [1],
"data": [
"0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027"
],
},
{
"name": "debug_probs",
"datatype": "FP32",
"shape": [0],
"data": [],
},
{
"name": "debug_tokens",
"datatype": "BYTES",
"shape": [0],
"data": [],
},
],
}
mock_response.json = return_val
mock_response.status_code = 200
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
return_value=mock_response,
) as mock_post:
response = litellm.completion(
model="triton/llama-3-8b-instruct",
messages=[
{
"role": "user",
"content": "0004900005025 0004900005026 0004900005027",
}
],
api_base="http://localhost:8000/infer",
)
print("litellm response", response.model_dump_json(indent=4))
# Verify the call was made
mock_post.assert_called_once()
# Get the arguments passed to the post request
call_kwargs = mock_post.call_args.kwargs
# Verify URL
assert call_kwargs["url"] == "http://localhost:8000/infer"
# Parse the request data from the JSON string
request_data = json.loads(call_kwargs["data"])
# Verify request matches expected Triton format
assert request_data["inputs"][0]["name"] == "text_input"
assert request_data["inputs"][0]["shape"] == [1]
assert request_data["inputs"][0]["datatype"] == "BYTES"
assert request_data["inputs"][0]["data"] == [
"0004900005025 0004900005026 0004900005027"
]
assert request_data["inputs"][1]["shape"] == [1]
assert request_data["inputs"][1]["datatype"] == "INT32"
assert request_data["inputs"][1]["data"] == [20]
# Verify response format matches expected completion format
assert (
response.choices[0].message.content
== "0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027"
)
assert response.choices[0].finish_reason == "stop"
assert response.choices[0].index == 0
assert response.object == "chat.completion"
except Exception as e:
print("exception", e)
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_triton_embeddings():
try:
litellm.set_verbose = True
response = await litellm.aembedding(
model="triton/my-triton-model",
api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
input=["good morning from litellm"],
)
print(f"response: {response}")
# stubbed endpoint is setup to return this
assert response.data[0]["embedding"] == [0.1, 0.2]
except Exception as e:
pytest.fail(f"Error occurred: {e}")


@ -888,23 +888,6 @@ def test_voyage_embeddings():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_triton_embeddings():
try:
litellm.set_verbose = True
response = await litellm.aembedding(
model="triton/my-triton-model",
api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
input=["good morning from litellm"],
)
print(f"response: {response}")
# stubbed endpoint is setup to return this
assert response.data[0]["embedding"] == [0.1, 0.2]
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
"input", ["good morning from litellm", ["good morning from litellm"]] #


@ -1,56 +0,0 @@
import pytest
from litellm.llms.triton.completion.handler import TritonChatCompletion
def test_split_embedding_by_shape_passes():
try:
triton = TritonChatCompletion()
data = [
{
"shape": [2, 3],
"data": [1, 2, 3, 4, 5, 6],
}
]
split_output_data = triton.split_embedding_by_shape(
data[0]["data"], data[0]["shape"]
)
assert split_output_data == [[1, 2, 3], [4, 5, 6]]
except Exception as e:
pytest.fail(f"An exception occured: {e}")
def test_split_embedding_by_shape_fails_with_shape_value_error():
triton = TritonChatCompletion()
data = [
{
"shape": [2],
"data": [1, 2, 3, 4, 5, 6],
}
]
with pytest.raises(ValueError):
triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
def test_completion_triton():
from litellm import completion
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock
client = HTTPHandler()
with patch.object(client, "post") as mock_post:
try:
response = completion(
model="triton/llama-3-8b-instruct",
messages=[{"role": "user", "content": "who are u?"}],
max_tokens=10,
timeout=5,
client=client,
api_base="http://localhost:8000/generate",
)
print(response)
except Exception as e:
print(e)
mock_post.assert_called_once()
print(mock_post.call_args.kwargs)