mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00

[Bug fix]: Triton /infer handler incompatible with batch responses (#7337)

* migrate triton to base llm http handler
* clean up triton handler.py
* use transform functions for triton
* add TritonConfig
* get openai params for triton
* use triton embedding config
* test_completion_triton_generate_api
* test_completion_triton_infer_api
* fix TritonConfig doc string
* use TritonResponseIterator
* fix triton embeddings
* docs triton chat usage
This commit is contained in:
parent
70a9ea99f2
commit
6107f9f3f3
11 changed files with 814 additions and 450 deletions
@@ -5,14 +5,190 @@ import TabItem from '@theme/TabItem';

LiteLLM supports Embedding Models on Triton Inference Servers

| Property | Details |
|-------|-------|
| Description | NVIDIA Triton Inference Server |
| Provider Route on LiteLLM | `triton/` |
| Supported Operations | `/chat/completion`, `/completion`, `/embedding` |
| Supported Triton endpoints | `/infer`, `/generate`, `/embeddings` |
| Link to Provider Doc | [Triton Inference Server ↗](https://github.com/triton-inference-server/server) |
## Usage

## Triton `/generate` - Chat Completion

<Tabs>
<TabItem value="sdk" label="SDK">

Use the `triton/` prefix to route to triton server

```python
from litellm import completion

response = completion(
    model="triton/llama-3-8b-instruct",
    messages=[{"role": "user", "content": "who are u?"}],
    max_tokens=10,
    api_base="http://localhost:8000/generate",
)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: my-triton-model
    litellm_params:
      model: triton/<your-triton-model>
      api_base: https://your-triton-api-base/triton/generate
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```

3. Send Request to LiteLLM Proxy Server

<Tabs>

<TabItem value="openai" label="OpenAI Python v1.0.0+">

```python
import openai
from openai import OpenAI

# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="my-triton-model",
    messages=[{"role": "user", "content": "who are u?"}],
    max_tokens=10,
)

print(response)
```

</TabItem>

<TabItem value="curl" label="curl">

`--header` is optional, only required if you're using litellm proxy with Virtual Keys

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data ' {
  "model": "my-triton-model",
  "messages": [{"role": "user", "content": "who are u?"}]
}'
```

</TabItem>

</Tabs>

</TabItem>
</Tabs>

## Triton `/infer` - Chat Completion

<Tabs>
<TabItem value="sdk" label="SDK">

### Example Call

Use the `triton/` prefix to route to triton server

```python
from litellm import completion

response = completion(
    model="triton/llama-3-8b-instruct",
    messages=[{"role": "user", "content": "who are u?"}],
    max_tokens=10,
    api_base="http://localhost:8000/infer",
)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: my-triton-model
    litellm_params:
      model: triton/<your-triton-model>
      api_base: https://your-triton-api-base/triton/infer
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```

3. Send Request to LiteLLM Proxy Server

<Tabs>

<TabItem value="openai" label="OpenAI Python v1.0.0+">

```python
import openai
from openai import OpenAI

# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="my-triton-model",
    messages=[{"role": "user", "content": "who are u?"}],
    max_tokens=10,
)

print(response)
```

</TabItem>

<TabItem value="curl" label="curl">

`--header` is optional, only required if you're using litellm proxy with Virtual Keys

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data ' {
  "model": "my-triton-model",
  "messages": [{"role": "user", "content": "who are u?"}]
}'
```

</TabItem>

</Tabs>

</TabItem>
</Tabs>

## Triton `/embeddings` - Embedding

<Tabs>
<TabItem value="sdk" label="SDK">

Use the `triton/` prefix to route to triton server

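A minimal SDK sketch (the model name and `api_base` are illustrative placeholders; the same call shape is exercised by `test_triton_embeddings` further down in this diff):

```python
from litellm import embedding

# illustrative model name and api_base; point these at your Triton server
response = embedding(
    model="triton/my-triton-model",
    api_base="https://your-triton-api-base/triton/embeddings",
    input=["good morning from litellm"],
)
print(response.data[0]["embedding"])
```
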
@@ -1019,6 +1019,9 @@ from .llms.anthropic.experimental_pass_through.transformation import (
from .llms.groq.stt.transformation import GroqSTTConfig
from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.triton.completion.transformation import TritonConfig
from .llms.triton.completion.transformation import TritonGenerateConfig
from .llms.triton.completion.transformation import TritonInferConfig
from .llms.triton.embedding.transformation import TritonEmbeddingConfig
from .llms.databricks.chat.transformation import DatabricksConfig
from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase.chat.transformation import PredibaseConfig
@@ -183,4 +183,11 @@ def get_supported_openai_params( # noqa: PLR0915
        return litellm.PredibaseConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "voyage":
        return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "triton":
        if request_type == "embeddings":
            return litellm.TritonEmbeddingConfig().get_supported_openai_params(
                model=model
            )
        else:
            return litellm.TritonConfig().get_supported_openai_params(model=model)
    return None
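For illustration, this routing means the advertised params differ by request type. A small sketch, assuming `litellm.get_supported_openai_params` is the public wrapper with the `request_type` argument used elsewhere in this diff:

```python
import litellm

# chat/completion requests route to TritonConfig
print(litellm.get_supported_openai_params(
    model="llama-3-8b-instruct", custom_llm_provider="triton"
))  # ["max_tokens", "max_completion_tokens"]

# embedding requests route to TritonEmbeddingConfig, which maps no OpenAI params yet
print(litellm.get_supported_openai_params(
    model="my-triton-model", custom_llm_provider="triton", request_type="embeddings"
))  # []
```
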
@@ -1,327 +1,5 @@
"""
Triton Completion - uses `llm_http_handler.py` to make httpx requests

Request/Response transformation is handled in `transformation.py`
"""

import json
from typing import Any, List, Optional, Union

import litellm
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.utils import Choices, EmbeddingResponse, Message, ModelResponse

from ...base import BaseLLM
from ..common_utils import TritonError


class TritonChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    async def aembedding(
        self,
        data: dict,
        model_response: litellm.utils.EmbeddingResponse,
        api_base: str,
        logging_obj: Any,
        api_key: Optional[str] = None,
    ) -> EmbeddingResponse:
        async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
        )

        response = await async_handler.post(url=api_base, data=json.dumps(data))

        if response.status_code != 200:
            raise TritonError(status_code=response.status_code, message=response.text)

        _text_response = response.text

        logging_obj.post_call(original_response=_text_response)

        _json_response = response.json()
        _embedding_output = []

        _outputs = _json_response["outputs"]
        for output in _outputs:
            _shape = output["shape"]
            _data = output["data"]
            _split_output_data = self.split_embedding_by_shape(_data, _shape)

            for idx, embedding in enumerate(_split_output_data):
                _embedding_output.append(
                    {
                        "object": "embedding",
                        "index": idx,
                        "embedding": embedding,
                    }
                )

        model_response.model = _json_response.get("model_name", "None")
        model_response.data = _embedding_output

        return model_response

    async def embedding(
        self,
        model: str,
        input: List[str],
        timeout: float,
        api_base: str,
        model_response: litellm.utils.EmbeddingResponse,
        logging_obj: Any,
        optional_params: dict,
        api_key: Optional[str] = None,
        client=None,
        aembedding: bool = False,
    ) -> EmbeddingResponse:
        data_for_triton = {
            "inputs": [
                {
                    "name": "input_text",
                    "shape": [len(input)],
                    "datatype": "BYTES",
                    "data": input,
                }
            ]
        }

        curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"

        logging_obj.pre_call(
            input="",
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": curl_string,
            },
        )

        if aembedding:
            response = await self.aembedding(  # type: ignore
                data=data_for_triton,
                model_response=model_response,
                logging_obj=logging_obj,
                api_base=api_base,
                api_key=api_key,
            )
            return response
        else:
            raise Exception(
                "Only async embedding supported for triton, please use litellm.aembedding() for now"
            )

    def completion(
        self,
        model: str,
        messages: List,
        timeout: float,
        api_base: str,
        logging_obj: Any,
        optional_params: dict,
        litellm_params: dict,
        model_response: ModelResponse,
        api_key: Optional[str] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
        stream: Optional[bool] = False,
        acompletion: bool = False,
        headers: Optional[dict] = None,
    ) -> ModelResponse:
        type_of_model = ""
        optional_params.pop("stream", False)
        if api_base.endswith("generate"):  ### This is a trtllm model
            data_for_triton = litellm.TritonConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers or {},
            )
            type_of_model = "trtllm"

        elif api_base.endswith(
            "infer"
        ):  ### This is an infer model with a custom model on triton
            text_input = messages[0]["content"]
            data_for_triton = {
                "inputs": [
                    {
                        "name": "text_input",
                        "shape": [1],
                        "datatype": "BYTES",
                        "data": [text_input],
                    }
                ]
            }

            for k, v in optional_params.items():
                if not (k == "stream" or k == "max_retries"):
                    datatype = "INT32" if isinstance(v, int) else "BYTES"
                    datatype = "FP32" if isinstance(v, float) else datatype
                    data_for_triton["inputs"].append(
                        {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
                    )

            if "max_tokens" not in optional_params:
                data_for_triton["inputs"].append(
                    {
                        "name": "max_tokens",
                        "shape": [1],
                        "datatype": "INT32",
                        "data": [20],
                    }
                )

            type_of_model = "infer"
        else:  ## Unknown model type passthrough
            data_for_triton = {
                "inputs": [
                    {
                        "name": "text_input",
                        "shape": [1],
                        "datatype": "BYTES",
                        "data": [messages[0]["content"]],
                    }
                ]
            }

        if logging_obj:
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "complete_input_dict": optional_params,
                    "api_base": api_base,
                    "http_client": client,
                },
            )

        headers = litellm.TritonConfig().validate_environment(
            headers=headers or {},
            model=model,
            messages=messages,
            optional_params=optional_params,
            api_key=api_key,
        )
        json_data_for_triton: str = json.dumps(data_for_triton)

        if acompletion:
            return self.acompletion(  # type: ignore
                model,
                json_data_for_triton,
                headers=headers,
                logging_obj=logging_obj,
                api_base=api_base,
                stream=stream,
                model_response=model_response,
                type_of_model=type_of_model,
            )

        if client is None or not isinstance(client, HTTPHandler):
            handler = _get_httpx_client()
        else:
            handler = client

        if stream:
            return self._handle_stream(  # type: ignore
                handler, api_base, json_data_for_triton, model, logging_obj
            )
        else:
            response = handler.post(
                url=api_base, data=json_data_for_triton, headers=headers
            )
            return self._handle_response(
                response, model_response, logging_obj, type_of_model=type_of_model
            )

    async def acompletion(
        self,
        model: str,
        data_for_triton,
        api_base,
        stream,
        logging_obj,
        headers,
        model_response,
        type_of_model,
    ) -> ModelResponse:
        handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
        )
        if stream:
            return self._ahandle_stream(  # type: ignore
                handler, api_base, data_for_triton, model, logging_obj
            )
        else:
            response = await handler.post(
                url=api_base, data=data_for_triton, headers=headers
            )

            return self._handle_response(
                response, model_response, logging_obj, type_of_model=type_of_model
            )

    def _handle_stream(self, handler, api_base, data_for_triton, model, logging_obj):
        response = handler.post(
            url=api_base + "_stream", data=data_for_triton, stream=True
        )
        streamwrapper = litellm.CustomStreamWrapper(
            response.iter_lines(),
            model=model,
            custom_llm_provider="triton",
            logging_obj=logging_obj,
        )
        for chunk in streamwrapper:
            yield (chunk)

    async def _ahandle_stream(
        self, handler, api_base, data_for_triton, model, logging_obj
    ):
        response = await handler.post(
            url=api_base + "_stream", data=data_for_triton, stream=True
        )
        streamwrapper = litellm.CustomStreamWrapper(
            response.aiter_lines(),
            model=model,
            custom_llm_provider="triton",
            logging_obj=logging_obj,
        )
        async for chunk in streamwrapper:
            yield (chunk)

    def _handle_response(self, response, model_response, logging_obj, type_of_model):
        if logging_obj:
            logging_obj.post_call(original_response=response)

        if response.status_code != 200:
            raise TritonError(status_code=response.status_code, message=response.text)

        _json_response = response.json()
        model_response.model = _json_response.get("model_name", "None")
        if type_of_model == "trtllm":
            model_response.choices = [
                Choices(index=0, message=Message(content=_json_response["text_output"]))
            ]
        elif type_of_model == "infer":
            model_response.choices = [
                Choices(
                    index=0,
                    message=Message(content=_json_response["outputs"][0]["data"]),
                )
            ]
        else:
            model_response.choices = [
                Choices(index=0, message=Message(content=_json_response["outputs"]))
            ]
        return model_response

    @staticmethod
    def split_embedding_by_shape(
        data: List[float], shape: List[int]
    ) -> List[List[float]]:
        if len(shape) != 2:
            raise ValueError("Shape must be of length 2.")
        embedding_size = shape[1]
        return [
            data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
        ]
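The replacement flow follows the transform/post/transform pattern: the generic HTTP handler posts whatever the provider config's `transform_request` returns and feeds the raw response back through `transform_response`. A minimal sketch using the `TritonConfig` introduced in the following hunks (argument values are illustrative; the exact prompt string depends on `prompt_factory`):

```python
import litellm

config = litellm.TritonConfig()

# /generate (trtllm) request body, as built by TritonGenerateConfig
payload = config.transform_request(
    model="llama-3-8b-instruct",
    messages=[{"role": "user", "content": "who are u?"}],
    optional_params={"max_tokens": 10},
    litellm_params={"api_base": "http://localhost:8000/generate"},
    headers={},
)
# payload == {"text_input": "<prompt>",
#             "parameters": {"max_tokens": 10, "bad_words": [""], "stop_words": [""]},
#             "stream": False}
```
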
@@ -2,44 +2,37 @@
Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint.
"""

import json
from typing import Any, Dict, List, Literal, Optional, Union

from httpx import Headers, Response

from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import (
    BaseConfig,
    BaseLLMException,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    Choices,
    GenericStreamingChunk,
    Message,
    ModelResponse,
)

from ..common_utils import TritonError


class TritonConfig(BaseConfig):
    """
    Base class for Triton configurations.

    Handles routing between /infer and /generate triton completion llms
    """

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
@@ -48,6 +41,16 @@ class TritonConfig(BaseConfig):
            status_code=status_code, message=error_message, headers=headers
        )

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        api_key: Optional[str] = None,
    ) -> Dict:
        return {"Content-Type": "application/json"}

    def get_supported_openai_params(self, model: str) -> List:
        return ["max_tokens", "max_completion_tokens"]
@@ -77,16 +80,236 @@ class TritonConfig(BaseConfig):
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        api_base = litellm_params.get("api_base", "")
        llm_type = self._get_triton_llm_type(api_base)
        if llm_type == "generate":
            return TritonGenerateConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        elif llm_type == "infer":
            return TritonInferConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        return model_response

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        api_base = litellm_params.get("api_base", "")
        llm_type = self._get_triton_llm_type(api_base)
        if llm_type == "generate":
            return TritonGenerateConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        elif llm_type == "infer":
            return TritonInferConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        return {}

    def _get_triton_llm_type(self, api_base: str) -> Literal["generate", "infer"]:
        if api_base.endswith("/generate"):
            return "generate"
        elif api_base.endswith("/infer"):
            return "infer"
        else:
            raise ValueError(f"Invalid Triton API base: {api_base}")


class TritonGenerateConfig(TritonConfig):
    """
    Transformations for triton /generate endpoint (This is a trtllm model)
    """

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        inference_params = optional_params.copy()
        stream = inference_params.pop("stream", False)
        data_for_triton: Dict[str, Any] = {
            "text_input": prompt_factory(model=model, messages=messages),
            "parameters": {
                "max_tokens": int(optional_params.get("max_tokens", 2000)),
                "bad_words": [""],
                "stop_words": [""],
            },
            "stream": bool(stream),
        }
        data_for_triton["parameters"].update(inference_params)
        return data_for_triton

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        model_response.choices = [
            Choices(index=0, message=Message(content=raw_response_json["text_output"]))
        ]

        return model_response


class TritonInferConfig(TritonGenerateConfig):
    """
    Transformations for triton /infer endpoint (This is an infer model with a custom model on triton)
    """

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        text_input = messages[0].get("content", "")
        data_for_triton = {
            "inputs": [
                {
                    "name": "text_input",
                    "shape": [1],
                    "datatype": "BYTES",
                    "data": [text_input],
                }
            ]
        }

        for k, v in optional_params.items():
            if not (k == "stream" or k == "max_retries"):
                datatype = "INT32" if isinstance(v, int) else "BYTES"
                datatype = "FP32" if isinstance(v, float) else datatype
                data_for_triton["inputs"].append(
                    {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
                )

        if "max_tokens" not in optional_params:
            data_for_triton["inputs"].append(
                {
                    "name": "max_tokens",
                    "shape": [1],
                    "datatype": "INT32",
                    "data": [20],
                }
            )
        return data_for_triton

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        _triton_response_data = raw_response_json["outputs"][0]["data"]
        triton_response_data: Optional[str] = None
        if isinstance(_triton_response_data, list):
            triton_response_data = "".join(_triton_response_data)
        else:
            triton_response_data = _triton_response_data

        model_response.choices = [
            Choices(
                index=0,
                message=Message(content=triton_response_data),
            )
        ]

        return model_response


class TritonResponseIterator(BaseModelResponseIterator):
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None
            index = int(chunk.get("index", 0))

            # set values
            text = chunk.get("text_output", "")
            finish_reason = chunk.get("stop_reason", "")
            is_finished = chunk.get("is_finished", False)

            return GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )
        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
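To make the routing and the batch handling concrete, a small sketch with hypothetical response data (`SimpleNamespace` stands in for an `httpx.Response`): `TritonConfig` picks the sub-config from the `api_base` suffix, and `TritonInferConfig.transform_response` now joins a list-valued `outputs[0]["data"]` (a batch response) into one string instead of assigning the raw list as message content.

```python
from types import SimpleNamespace

import litellm

cfg = litellm.TritonConfig()

# api_base suffix decides which sub-config handles the call
assert cfg._get_triton_llm_type("http://localhost:8000/generate") == "generate"
assert cfg._get_triton_llm_type("http://localhost:8000/infer") == "infer"

# /infer responses may carry a batch: outputs[0]["data"] is a list of strings.
fake_raw = SimpleNamespace(
    json=lambda: {"model_name": "basketgpt", "outputs": [{"data": ["Hello, ", "world"]}]},
    text="",
    status_code=200,
)
model_response = litellm.TritonInferConfig().transform_response(
    model="my-model",
    raw_response=fake_raw,  # stands in for httpx.Response
    model_response=litellm.ModelResponse(),
    logging_obj=None,       # not used by this transform
    request_data={},
    messages=[],
    optional_params={},
    litellm_params={},
    encoding=None,
    api_key=None,
    json_mode=None,
)
assert model_response.choices[0].message.content == "Hello, world"
```
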
litellm/llms/triton/embedding/transformation.py (new file, 121 lines)
@@ -0,0 +1,121 @@
from typing import List, Optional, Union

import httpx

from litellm.llms.base_llm.chat.transformation import AllMessageValues, BaseLLMException
from litellm.llms.base_llm.embedding.transformation import (
    BaseEmbeddingConfig,
    LiteLLMLoggingObj,
)
from litellm.types.utils import EmbeddingResponse

from ..common_utils import TritonError


class TritonEmbeddingConfig(BaseEmbeddingConfig):
    """
    Transformations for triton /embeddings endpoint (This is a trtllm model)
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self, model: str) -> list:
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params to Triton Embedding params
        """
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        return {}

    def transform_embedding_request(
        self,
        model: str,
        input: Union[str, List[str], List[float], List[List[float]]],
        optional_params: dict,
        headers: dict,
    ) -> dict:
        return {
            "inputs": [
                {
                    "name": "input_text",
                    "shape": [len(input)],
                    "datatype": "BYTES",
                    "data": input,
                }
            ]
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        _embedding_output = []

        _outputs = raw_response_json["outputs"]
        for output in _outputs:
            _shape = output["shape"]
            _data = output["data"]
            _split_output_data = self.split_embedding_by_shape(_data, _shape)

            for idx, embedding in enumerate(_split_output_data):
                _embedding_output.append(
                    {
                        "object": "embedding",
                        "index": idx,
                        "embedding": embedding,
                    }
                )

        model_response.model = raw_response_json.get("model_name", "None")
        model_response.data = _embedding_output
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return TritonError(
            message=error_message, status_code=status_code, headers=headers
        )

    @staticmethod
    def split_embedding_by_shape(
        data: List[float], shape: List[int]
    ) -> List[List[float]]:
        if len(shape) != 2:
            raise ValueError("Shape must be of length 2.")
        embedding_size = shape[1]
        return [
            data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
        ]
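A quick sketch of how these two transforms round-trip a batch (illustrative values): the request packs the whole input list into one BYTES tensor, and `split_embedding_by_shape` slices the flat output back into per-item vectors.

```python
from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig

cfg = TritonEmbeddingConfig()

# Two inputs -> one BYTES tensor of shape [2]
payload = cfg.transform_embedding_request(
    model="my-triton-model",
    input=["hello", "world"],
    optional_params={},
    headers={},
)
assert payload["inputs"][0]["shape"] == [2]
assert payload["inputs"][0]["data"] == ["hello", "world"]

# A flat output of shape [2, 3] is split into two 3-dim embeddings
assert TritonEmbeddingConfig.split_embedding_by_shape(
    [1, 2, 3, 4, 5, 6], [2, 3]
) == [[1, 2, 3], [4, 5, 6]]
```
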
@@ -128,7 +128,6 @@ from .llms.replicate.chat.handler import completion as replicate_chat_completion
from .llms.sagemaker.chat.handler import SagemakerChatHandler
from .llms.sagemaker.completion.handler import SagemakerLLM
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton.completion.handler import TritonChatCompletion
from .llms.vertex_ai import vertex_ai_non_gemini
from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
@@ -194,7 +193,6 @@ azure_audio_transcriptions = AzureAudioTranscription()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
codestral_text_completions = CodestralTextCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
bedrock_embedding = BedrockEmbedding()
@@ -2711,24 +2709,22 @@ def completion( # type: ignore # noqa: PLR0915

    elif custom_llm_provider == "triton":
        api_base = litellm.api_base or api_base
        response = base_llm_http_handler.completion(
            model=model,
            stream=stream,
            messages=messages,
            acompletion=acompletion,
            api_base=api_base,
            model_response=model_response,
            optional_params=optional_params,
            litellm_params=litellm_params,
            custom_llm_provider=custom_llm_provider,
            timeout=timeout,
            headers=headers,
            encoding=encoding,
            api_key=api_key,
            logging_obj=logging,
        )

    elif custom_llm_provider == "cloudflare":
        api_key = (
            api_key
@@ -3477,9 +3473,10 @@ def embedding( # noqa: PLR0915
            raise ValueError(
                "api_base is required for triton. Please pass `api_base`"
            )
        response = base_llm_http_handler.embedding(
            model=model,
            input=input,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            api_key=api_key,
            logging_obj=logging,
@@ -2249,10 +2249,19 @@ def get_optional_params_embeddings( # noqa: PLR0915
                message="Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
            )
    elif custom_llm_provider == "triton":
        supported_params = get_supported_openai_params(
            model=model,
            custom_llm_provider=custom_llm_provider,
            request_type="embeddings",
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.TritonEmbeddingConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params={},
            model=model,
            drop_params=drop_params if drop_params is not None else False,
        )
        final_params = {**optional_params, **kwargs}
        return final_params
    elif custom_llm_provider == "databricks":
        supported_params = get_supported_openai_params(
|
||||||
else False
|
else False
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "triton":
|
||||||
|
supported_params = get_supported_openai_params(
|
||||||
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
)
|
||||||
|
_check_valid_arg(supported_params=supported_params)
|
||||||
|
optional_params = litellm.TritonConfig().map_openai_params(
|
||||||
|
non_default_params=non_default_params,
|
||||||
|
optional_params=optional_params,
|
||||||
|
model=model,
|
||||||
|
drop_params=drop_params if drop_params is not None else False,
|
||||||
|
)
|
||||||
|
|
||||||
elif custom_llm_provider == "maritalk":
|
elif custom_llm_provider == "maritalk":
|
||||||
## check if unsupported param passed in
|
## check if unsupported param passed in
|
||||||
|
@ -6222,6 +6242,8 @@ class ProviderConfigManager:
|
||||||
) -> BaseEmbeddingConfig:
|
) -> BaseEmbeddingConfig:
|
||||||
if litellm.LlmProviders.VOYAGE == provider:
|
if litellm.LlmProviders.VOYAGE == provider:
|
||||||
return litellm.VoyageEmbeddingConfig()
|
return litellm.VoyageEmbeddingConfig()
|
||||||
|
elif litellm.LlmProviders.TRITON == provider:
|
||||||
|
return litellm.TritonEmbeddingConfig()
|
||||||
raise ValueError(f"Provider {provider} does not support embedding config")
|
raise ValueError(f"Provider {provider} does not support embedding config")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
tests/llm_translation/test_triton.py (new file, 210 lines)
@@ -0,0 +1,210 @@
import json
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()
import io
from unittest.mock import AsyncMock, MagicMock, patch

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm

import pytest
from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig
import litellm


def test_split_embedding_by_shape_passes():
    try:
        data = [
            {
                "shape": [2, 3],
                "data": [1, 2, 3, 4, 5, 6],
            }
        ]
        split_output_data = TritonEmbeddingConfig.split_embedding_by_shape(
            data[0]["data"], data[0]["shape"]
        )
        assert split_output_data == [[1, 2, 3], [4, 5, 6]]
    except Exception as e:
        pytest.fail(f"An exception occurred: {e}")


def test_split_embedding_by_shape_fails_with_shape_value_error():
    data = [
        {
            "shape": [2],
            "data": [1, 2, 3, 4, 5, 6],
        }
    ]
    with pytest.raises(ValueError):
        TritonEmbeddingConfig.split_embedding_by_shape(
            data[0]["data"], data[0]["shape"]
        )


def test_completion_triton_generate_api():
    try:
        mock_response = MagicMock()

        def return_val():
            return {
                "text_output": "I am an AI assistant",
            }

        mock_response.json = return_val
        mock_response.status_code = 200

        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            return_value=mock_response,
        ) as mock_post:
            response = litellm.completion(
                model="triton/llama-3-8b-instruct",
                messages=[{"role": "user", "content": "who are u?"}],
                max_tokens=10,
                timeout=5,
                api_base="http://localhost:8000/generate",
            )

            # Verify the call was made
            mock_post.assert_called_once()

            # Get the arguments passed to the post request
            print("call args", mock_post.call_args)
            call_kwargs = mock_post.call_args.kwargs  # Access kwargs directly

            # Verify URL
            assert call_kwargs["url"] == "http://localhost:8000/generate"

            # Parse the request data from the JSON string
            request_data = json.loads(call_kwargs["data"])

            # Verify request data
            assert request_data["text_input"] == "who are u?"
            assert request_data["parameters"]["max_tokens"] == 10

            # Verify response
            assert response.choices[0].message.content == "I am an AI assistant"

    except Exception as e:
        print("exception", e)
        import traceback

        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")


def test_completion_triton_infer_api():
    litellm.set_verbose = True
    try:
        mock_response = MagicMock()

        def return_val():
            return {
                "model_name": "basketgpt",
                "model_version": "2",
                "outputs": [
                    {
                        "name": "text_output",
                        "datatype": "BYTES",
                        "shape": [1],
                        "data": [
                            "0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027"
                        ],
                    },
                    {
                        "name": "debug_probs",
                        "datatype": "FP32",
                        "shape": [0],
                        "data": [],
                    },
                    {
                        "name": "debug_tokens",
                        "datatype": "BYTES",
                        "shape": [0],
                        "data": [],
                    },
                ],
            }

        mock_response.json = return_val
        mock_response.status_code = 200

        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            return_value=mock_response,
        ) as mock_post:
            response = litellm.completion(
                model="triton/llama-3-8b-instruct",
                messages=[
                    {
                        "role": "user",
                        "content": "0004900005025 0004900005026 0004900005027",
                    }
                ],
                api_base="http://localhost:8000/infer",
            )

            print("litellm response", response.model_dump_json(indent=4))

            # Verify the call was made
            mock_post.assert_called_once()

            # Get the arguments passed to the post request
            call_kwargs = mock_post.call_args.kwargs

            # Verify URL
            assert call_kwargs["url"] == "http://localhost:8000/infer"

            # Parse the request data from the JSON string
            request_data = json.loads(call_kwargs["data"])

            # Verify request matches expected Triton format
            assert request_data["inputs"][0]["name"] == "text_input"
            assert request_data["inputs"][0]["shape"] == [1]
            assert request_data["inputs"][0]["datatype"] == "BYTES"
            assert request_data["inputs"][0]["data"] == [
                "0004900005025 0004900005026 0004900005027"
            ]

            assert request_data["inputs"][1]["shape"] == [1]
            assert request_data["inputs"][1]["datatype"] == "INT32"
            assert request_data["inputs"][1]["data"] == [20]

            # Verify response format matches expected completion format
            assert (
                response.choices[0].message.content
                == "0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027"
            )
            assert response.choices[0].finish_reason == "stop"
            assert response.choices[0].index == 0
            assert response.object == "chat.completion"

    except Exception as e:
        print("exception", e)
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio
async def test_triton_embeddings():
    try:
        litellm.set_verbose = True
        response = await litellm.aembedding(
            model="triton/my-triton-model",
            api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
            input=["good morning from litellm"],
        )
        print(f"response: {response}")

        # stubbed endpoint is setup to return this
        assert response.data[0]["embedding"] == [0.1, 0.2]
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -888,23 +888,6 @@ def test_voyage_embeddings():
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio
async def test_triton_embeddings():
    try:
        litellm.set_verbose = True
        response = await litellm.aembedding(
            model="triton/my-triton-model",
            api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
            input=["good morning from litellm"],
        )
        print(f"response: {response}")

        # stubbed endpoint is setup to return this
        assert response.data[0]["embedding"] == [0.1, 0.2]
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
    "input", ["good morning from litellm", ["good morning from litellm"]]  #
@@ -1,56 +0,0 @@
import pytest
from litellm.llms.triton.completion.handler import TritonChatCompletion


def test_split_embedding_by_shape_passes():
    try:
        triton = TritonChatCompletion()
        data = [
            {
                "shape": [2, 3],
                "data": [1, 2, 3, 4, 5, 6],
            }
        ]
        split_output_data = triton.split_embedding_by_shape(
            data[0]["data"], data[0]["shape"]
        )
        assert split_output_data == [[1, 2, 3], [4, 5, 6]]
    except Exception as e:
        pytest.fail(f"An exception occurred: {e}")


def test_split_embedding_by_shape_fails_with_shape_value_error():
    triton = TritonChatCompletion()
    data = [
        {
            "shape": [2],
            "data": [1, 2, 3, 4, 5, 6],
        }
    ]
    with pytest.raises(ValueError):
        triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])


def test_completion_triton():
    from litellm import completion
    from litellm.llms.custom_httpx.http_handler import HTTPHandler
    from unittest.mock import patch, MagicMock, AsyncMock

    client = HTTPHandler()
    with patch.object(client, "post") as mock_post:
        try:
            response = completion(
                model="triton/llama-3-8b-instruct",
                messages=[{"role": "user", "content": "who are u?"}],
                max_tokens=10,
                timeout=5,
                client=client,
                api_base="http://localhost:8000/generate",
            )
            print(response)
        except Exception as e:
            print(e)

        mock_post.assert_called_once()

        print(mock_post.call_args.kwargs)