Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

[Bug fix]: Triton /infer handler incompatible with batch responses (#7337)

* migrate triton to base llm http handler
* clean up triton handler.py
* use transform functions for triton
* add TritonConfig
* get openai params for triton
* use triton embedding config
* test_completion_triton_generate_api
* test_completion_triton_infer_api
* fix TritonConfig doc string
* use TritonResponseIterator
* fix triton embeddings
* docs triton chat usage

This commit is contained in:
parent 70a9ea99f2
commit 6107f9f3f3

11 changed files with 814 additions and 450 deletions
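The incompatibility this PR fixes is in how `/infer` responses become a chat completion: Triton returns `outputs[0]["data"]` as a list, and with batch responses that list was previously passed straight through as the message content. The new `TritonInferConfig.transform_response` (see the `transformation.py` diff below) joins the list into a single string. A minimal standalone sketch of that behaviour, using made-up response data:

```python
# Sketch only: shows the joining behaviour of the new TritonInferConfig,
# not the actual litellm internals. The response dict below is made up.
raw_response_json = {
    "model_name": "basketgpt",
    "outputs": [
        {"name": "text_output", "datatype": "BYTES", "shape": [2],
         "data": ["hello ", "world"]},
    ],
}

_data = raw_response_json["outputs"][0]["data"]
# batch responses arrive as a list of strings; join them into one message
content = "".join(_data) if isinstance(_data, list) else _data
print(content)  # -> "hello world"
```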
@ -5,14 +5,190 @@ import TabItem from '@theme/TabItem';

LiteLLM supports Chat Completion and Embedding models on Triton Inference Servers.

| Property | Details |
|-------|-------|
| Description | NVIDIA Triton Inference Server |
| Provider Route on LiteLLM | `triton/` |
| Supported Operations | `/chat/completion`, `/completion`, `/embedding` |
| Supported Triton endpoints | `/infer`, `/generate`, `/embeddings` |
| Link to Provider Doc | [Triton Inference Server ↗](https://github.com/triton-inference-server/server) |

## Usage
## Triton `/generate` - Chat Completion

<Tabs>
<TabItem value="sdk" label="SDK">

Use the `triton/` prefix to route to your Triton server.

```python
from litellm import completion

response = completion(
    model="triton/llama-3-8b-instruct",
    messages=[{"role": "user", "content": "who are u?"}],
    max_tokens=10,
    api_base="http://localhost:8000/generate",
)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: my-triton-model
    litellm_params:
      model: triton/<your-triton-model>
      api_base: https://your-triton-api-base/triton/generate
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```

3. Send a request to the LiteLLM Proxy Server

<Tabs>

<TabItem value="openai" label="OpenAI Python v1.0.0+">

```python
import openai
from openai import OpenAI

# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="my-triton-model",
    messages=[{"role": "user", "content": "who are u?"}],
    max_tokens=10,
)

print(response)
```

</TabItem>

<TabItem value="curl" label="curl">

The `Authorization` header is optional; it is only required if you're using the LiteLLM proxy with Virtual Keys.

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
    "model": "my-triton-model",
    "messages": [{"role": "user", "content": "who are u?"}]
}'
```
</TabItem>

</Tabs>

</TabItem>
</Tabs>
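For reference, the `/generate` route sends a trtllm-style payload. Below is a sketch of the JSON body LiteLLM builds for the SDK call above, based on `TritonGenerateConfig.transform_request` in this commit; the prompt string is illustrative, the real one is rendered by `prompt_factory`.

```python
# Approximate /generate request body for the example above (sketch only).
data_for_triton = {
    "text_input": "who are u?",      # prompt rendered by prompt_factory()
    "parameters": {
        "max_tokens": 10,            # defaults to 2000 when not passed
        "bad_words": [""],
        "stop_words": [""],
    },
    "stream": False,
}
```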
## Triton `/infer` - Chat Completion

<Tabs>
<TabItem value="sdk" label="SDK">

### Example Call

Use the `triton/` prefix to route to your Triton server.

```python
from litellm import completion

response = completion(
    model="triton/llama-3-8b-instruct",
    messages=[{"role": "user", "content": "who are u?"}],
    max_tokens=10,
    api_base="http://localhost:8000/infer",
)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: my-triton-model
    litellm_params:
      model: triton/<your-triton-model>
      api_base: https://your-triton-api-base/triton/infer
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```

3. Send a request to the LiteLLM Proxy Server

<Tabs>

<TabItem value="openai" label="OpenAI Python v1.0.0+">

```python
import openai
from openai import OpenAI

# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="my-triton-model",
    messages=[{"role": "user", "content": "who are u?"}],
    max_tokens=10,
)

print(response)
```

</TabItem>

<TabItem value="curl" label="curl">

The `Authorization` header is optional; it is only required if you're using the LiteLLM proxy with Virtual Keys.

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
    "model": "my-triton-model",
    "messages": [{"role": "user", "content": "who are u?"}]
}'
```
</TabItem>

</Tabs>

</TabItem>
</Tabs>
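The `/infer` route instead uses Triton's KServe-style `inputs` format. Below is a sketch of the JSON body LiteLLM builds for the SDK call above, based on `TritonInferConfig.transform_request` in this commit.

```python
# Approximate /infer request body for the example above (sketch only).
# Extra OpenAI params are appended as additional inputs; max_tokens falls
# back to 20 if not supplied.
data_for_triton = {
    "inputs": [
        {"name": "text_input", "shape": [1], "datatype": "BYTES", "data": ["who are u?"]},
        {"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [10]},
    ]
}
```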
## Triton `/embeddings` - Embedding

<Tabs>
<TabItem value="sdk" label="SDK">

Use the `triton/` prefix to route to your Triton server.
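A minimal sketch based on the `test_triton_embeddings` test added in this commit; the model name and `api_base` are placeholders for your own deployment.

```python
import asyncio

import litellm


async def main():
    response = await litellm.aembedding(
        model="triton/my-triton-model",                              # placeholder
        api_base="https://your-triton-api-base/triton/embeddings",  # placeholder
        input=["good morning from litellm"],
    )
    print(response.data[0]["embedding"])


asyncio.run(main())
```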
@ -1019,6 +1019,9 @@ from .llms.anthropic.experimental_pass_through.transformation import (
from .llms.groq.stt.transformation import GroqSTTConfig
from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.triton.completion.transformation import TritonConfig
from .llms.triton.completion.transformation import TritonGenerateConfig
from .llms.triton.completion.transformation import TritonInferConfig
from .llms.triton.embedding.transformation import TritonEmbeddingConfig
from .llms.databricks.chat.transformation import DatabricksConfig
from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase.chat.transformation import PredibaseConfig
@ -183,4 +183,11 @@ def get_supported_openai_params( # noqa: PLR0915
        return litellm.PredibaseConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "voyage":
        return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "triton":
        if request_type == "embeddings":
            return litellm.TritonEmbeddingConfig().get_supported_openai_params(
                model=model
            )
        else:
            return litellm.TritonConfig().get_supported_openai_params(model=model)
    return None
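A quick sketch of what this routing returns, assuming the configs defined elsewhere in this commit (`TritonConfig` advertises `max_tokens`/`max_completion_tokens`, while `TritonEmbeddingConfig` advertises no OpenAI params):

```python
from litellm import get_supported_openai_params

# chat/completions route -> TritonConfig
print(get_supported_openai_params(model="my-model", custom_llm_provider="triton"))
# expected: ['max_tokens', 'max_completion_tokens']

# embeddings route -> TritonEmbeddingConfig
print(get_supported_openai_params(
    model="my-model", custom_llm_provider="triton", request_type="embeddings"
))
# expected: []
```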
@ -1,327 +1,5 @@
|
|||
import json
|
||||
from typing import Any, List, Optional, Union
|
||||
"""
|
||||
Triton Completion - uses `llm_http_handler.py` to make httpx requests
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.utils import Choices, EmbeddingResponse, Message, ModelResponse
|
||||
|
||||
from ...base import BaseLLM
|
||||
from ..common_utils import TritonError
|
||||
|
||||
|
||||
class TritonChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def aembedding(
|
||||
self,
|
||||
data: dict,
|
||||
model_response: litellm.utils.EmbeddingResponse,
|
||||
api_base: str,
|
||||
logging_obj: Any,
|
||||
api_key: Optional[str] = None,
|
||||
) -> EmbeddingResponse:
|
||||
async_handler = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
|
||||
)
|
||||
|
||||
response = await async_handler.post(url=api_base, data=json.dumps(data))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise TritonError(status_code=response.status_code, message=response.text)
|
||||
|
||||
_text_response = response.text
|
||||
|
||||
logging_obj.post_call(original_response=_text_response)
|
||||
|
||||
_json_response = response.json()
|
||||
_embedding_output = []
|
||||
|
||||
_outputs = _json_response["outputs"]
|
||||
for output in _outputs:
|
||||
_shape = output["shape"]
|
||||
_data = output["data"]
|
||||
_split_output_data = self.split_embedding_by_shape(_data, _shape)
|
||||
|
||||
for idx, embedding in enumerate(_split_output_data):
|
||||
_embedding_output.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding,
|
||||
}
|
||||
)
|
||||
|
||||
model_response.model = _json_response.get("model_name", "None")
|
||||
model_response.data = _embedding_output
|
||||
|
||||
return model_response
|
||||
|
||||
async def embedding(
|
||||
self,
|
||||
model: str,
|
||||
input: List[str],
|
||||
timeout: float,
|
||||
api_base: str,
|
||||
model_response: litellm.utils.EmbeddingResponse,
|
||||
logging_obj: Any,
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
client=None,
|
||||
aembedding: bool = False,
|
||||
) -> EmbeddingResponse:
|
||||
data_for_triton = {
|
||||
"inputs": [
|
||||
{
|
||||
"name": "input_text",
|
||||
"shape": [len(input)],
|
||||
"datatype": "BYTES",
|
||||
"data": input,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
|
||||
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": curl_string,
|
||||
},
|
||||
)
|
||||
|
||||
if aembedding:
|
||||
response = await self.aembedding( # type: ignore
|
||||
data=data_for_triton,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
)
|
||||
return response
|
||||
else:
|
||||
raise Exception(
|
||||
"Only async embedding supported for triton, please use litellm.aembedding() for now"
|
||||
)
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List,
|
||||
timeout: float,
|
||||
api_base: str,
|
||||
logging_obj: Any,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
model_response: ModelResponse,
|
||||
api_key: Optional[str] = None,
|
||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
acompletion: bool = False,
|
||||
headers: Optional[dict] = None,
|
||||
) -> ModelResponse:
|
||||
type_of_model = ""
|
||||
optional_params.pop("stream", False)
|
||||
if api_base.endswith("generate"): ### This is a trtllm model
|
||||
data_for_triton = litellm.TritonConfig().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers or {},
|
||||
)
|
||||
type_of_model = "trtllm"
|
||||
|
||||
elif api_base.endswith(
|
||||
"infer"
|
||||
): ### This is an infer model with a custom model on triton
|
||||
text_input = messages[0]["content"]
|
||||
data_for_triton = {
|
||||
"inputs": [
|
||||
{
|
||||
"name": "text_input",
|
||||
"shape": [1],
|
||||
"datatype": "BYTES",
|
||||
"data": [text_input],
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
for k, v in optional_params.items():
|
||||
if not (k == "stream" or k == "max_retries"):
|
||||
datatype = "INT32" if isinstance(v, int) else "BYTES"
|
||||
datatype = "FP32" if isinstance(v, float) else datatype
|
||||
data_for_triton["inputs"].append(
|
||||
{"name": k, "shape": [1], "datatype": datatype, "data": [v]}
|
||||
)
|
||||
|
||||
if "max_tokens" not in optional_params:
|
||||
data_for_triton["inputs"].append(
|
||||
{
|
||||
"name": "max_tokens",
|
||||
"shape": [1],
|
||||
"datatype": "INT32",
|
||||
"data": [20],
|
||||
}
|
||||
)
|
||||
|
||||
type_of_model = "infer"
|
||||
else: ## Unknown model type passthrough
|
||||
data_for_triton = {
|
||||
"inputs": [
|
||||
{
|
||||
"name": "text_input",
|
||||
"shape": [1],
|
||||
"datatype": "BYTES",
|
||||
"data": [messages[0]["content"]],
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
if logging_obj:
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"api_base": api_base,
|
||||
"http_client": client,
|
||||
},
|
||||
)
|
||||
|
||||
headers = litellm.TritonConfig().validate_environment(
|
||||
headers=headers or {},
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
)
|
||||
json_data_for_triton: str = json.dumps(data_for_triton)
|
||||
|
||||
if acompletion:
|
||||
return self.acompletion( # type: ignore
|
||||
model,
|
||||
json_data_for_triton,
|
||||
headers=headers,
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
stream=stream,
|
||||
model_response=model_response,
|
||||
type_of_model=type_of_model,
|
||||
)
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
handler = _get_httpx_client()
|
||||
else:
|
||||
handler = client
|
||||
|
||||
if stream:
|
||||
return self._handle_stream( # type: ignore
|
||||
handler, api_base, json_data_for_triton, model, logging_obj
|
||||
)
|
||||
else:
|
||||
response = handler.post(
|
||||
url=api_base, data=json_data_for_triton, headers=headers
|
||||
)
|
||||
return self._handle_response(
|
||||
response, model_response, logging_obj, type_of_model=type_of_model
|
||||
)
|
||||
|
||||
async def acompletion(
|
||||
self,
|
||||
model: str,
|
||||
data_for_triton,
|
||||
api_base,
|
||||
stream,
|
||||
logging_obj,
|
||||
headers,
|
||||
model_response,
|
||||
type_of_model,
|
||||
) -> ModelResponse:
|
||||
handler = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
|
||||
)
|
||||
if stream:
|
||||
return self._ahandle_stream( # type: ignore
|
||||
handler, api_base, data_for_triton, model, logging_obj
|
||||
)
|
||||
else:
|
||||
response = await handler.post(
|
||||
url=api_base, data=data_for_triton, headers=headers
|
||||
)
|
||||
|
||||
return self._handle_response(
|
||||
response, model_response, logging_obj, type_of_model=type_of_model
|
||||
)
|
||||
|
||||
def _handle_stream(self, handler, api_base, data_for_triton, model, logging_obj):
|
||||
response = handler.post(
|
||||
url=api_base + "_stream", data=data_for_triton, stream=True
|
||||
)
|
||||
streamwrapper = litellm.CustomStreamWrapper(
|
||||
response.iter_lines(),
|
||||
model=model,
|
||||
custom_llm_provider="triton",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
for chunk in streamwrapper:
|
||||
yield (chunk)
|
||||
|
||||
async def _ahandle_stream(
|
||||
self, handler, api_base, data_for_triton, model, logging_obj
|
||||
):
|
||||
response = await handler.post(
|
||||
url=api_base + "_stream", data=data_for_triton, stream=True
|
||||
)
|
||||
streamwrapper = litellm.CustomStreamWrapper(
|
||||
response.aiter_lines(),
|
||||
model=model,
|
||||
custom_llm_provider="triton",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
async for chunk in streamwrapper:
|
||||
yield (chunk)
|
||||
|
||||
def _handle_response(self, response, model_response, logging_obj, type_of_model):
|
||||
if logging_obj:
|
||||
logging_obj.post_call(original_response=response)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise TritonError(status_code=response.status_code, message=response.text)
|
||||
|
||||
_json_response = response.json()
|
||||
model_response.model = _json_response.get("model_name", "None")
|
||||
if type_of_model == "trtllm":
|
||||
model_response.choices = [
|
||||
Choices(index=0, message=Message(content=_json_response["text_output"]))
|
||||
]
|
||||
elif type_of_model == "infer":
|
||||
model_response.choices = [
|
||||
Choices(
|
||||
index=0,
|
||||
message=Message(content=_json_response["outputs"][0]["data"]),
|
||||
)
|
||||
]
|
||||
else:
|
||||
model_response.choices = [
|
||||
Choices(index=0, message=Message(content=_json_response["outputs"]))
|
||||
]
|
||||
return model_response
|
||||
|
||||
@staticmethod
|
||||
def split_embedding_by_shape(
|
||||
data: List[float], shape: List[int]
|
||||
) -> List[List[float]]:
|
||||
if len(shape) != 2:
|
||||
raise ValueError("Shape must be of length 2.")
|
||||
embedding_size = shape[1]
|
||||
return [
|
||||
data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
|
||||
]
|
||||
Request/Response transformation is handled in `transformation.py`
|
||||
"""
|
||||
|
|
|
@ -2,44 +2,37 @@
|
|||
Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from httpx import Headers, Response
|
||||
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.base_llm.chat.transformation import (
|
||||
BaseConfig,
|
||||
BaseLLMException,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
Choices,
|
||||
GenericStreamingChunk,
|
||||
Message,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
from ..common_utils import TritonError
|
||||
|
||||
|
||||
class TritonConfig(BaseConfig):
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
inference_params = optional_params.copy()
|
||||
stream = inference_params.pop("stream", False)
|
||||
data_for_triton: Dict[str, Any] = {
|
||||
"text_input": prompt_factory(model=model, messages=messages),
|
||||
"parameters": {
|
||||
"max_tokens": int(optional_params.get("max_tokens", 2000)),
|
||||
"bad_words": [""],
|
||||
"stop_words": [""],
|
||||
},
|
||||
"stream": bool(stream),
|
||||
}
|
||||
data_for_triton["parameters"].update(inference_params)
|
||||
return data_for_triton
|
||||
"""
|
||||
Base class for Triton configurations.
|
||||
|
||||
Handles routing between /infer and /generate triton completion llms
|
||||
"""
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[Dict, Headers]
|
||||
|
@ -48,6 +41,16 @@ class TritonConfig(BaseConfig):
|
|||
status_code=status_code, message=error_message, headers=headers
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: Dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> Dict:
|
||||
return {"Content-Type": "application/json"}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List:
|
||||
return ["max_tokens", "max_completion_tokens"]
|
||||
|
||||
|
@ -77,16 +80,236 @@ class TritonConfig(BaseConfig):
|
|||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
raise NotImplementedError(
|
||||
"response transformation done in handler.py. [TODO] Migrate here."
|
||||
api_base = litellm_params.get("api_base", "")
|
||||
llm_type = self._get_triton_llm_type(api_base)
|
||||
if llm_type == "generate":
|
||||
return TritonGenerateConfig().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
elif llm_type == "infer":
|
||||
return TritonInferConfig().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
return model_response
|
||||
|
||||
def validate_environment(
|
||||
def transform_request(
|
||||
self,
|
||||
headers: Dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
api_base = litellm_params.get("api_base", "")
|
||||
llm_type = self._get_triton_llm_type(api_base)
|
||||
if llm_type == "generate":
|
||||
return TritonGenerateConfig().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
elif llm_type == "infer":
|
||||
return TritonInferConfig().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
return {}
|
||||
|
||||
def _get_triton_llm_type(self, api_base: str) -> Literal["generate", "infer"]:
|
||||
if api_base.endswith("/generate"):
|
||||
return "generate"
|
||||
elif api_base.endswith("/infer"):
|
||||
return "infer"
|
||||
else:
|
||||
raise ValueError(f"Invalid Triton API base: {api_base}")
|
||||
|
||||
|
||||
class TritonGenerateConfig(TritonConfig):
|
||||
"""
|
||||
Transformations for triton /generate endpoint (This is a trtllm model)
|
||||
"""
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
inference_params = optional_params.copy()
|
||||
stream = inference_params.pop("stream", False)
|
||||
data_for_triton: Dict[str, Any] = {
|
||||
"text_input": prompt_factory(model=model, messages=messages),
|
||||
"parameters": {
|
||||
"max_tokens": int(optional_params.get("max_tokens", 2000)),
|
||||
"bad_words": [""],
|
||||
"stop_words": [""],
|
||||
},
|
||||
"stream": bool(stream),
|
||||
}
|
||||
data_for_triton["parameters"].update(inference_params)
|
||||
return data_for_triton
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: Dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
litellm_params: Dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
) -> Dict:
|
||||
return {"Content-Type": "application/json"}
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise TritonError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
model_response.choices = [
|
||||
Choices(index=0, message=Message(content=raw_response_json["text_output"]))
|
||||
]
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
class TritonInferConfig(TritonGenerateConfig):
|
||||
"""
|
||||
Transformations for the Triton /infer endpoint (this is an infer model with a custom model on Triton)
|
||||
"""
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
|
||||
text_input = messages[0].get("content", "")
|
||||
data_for_triton = {
|
||||
"inputs": [
|
||||
{
|
||||
"name": "text_input",
|
||||
"shape": [1],
|
||||
"datatype": "BYTES",
|
||||
"data": [text_input],
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
for k, v in optional_params.items():
|
||||
if not (k == "stream" or k == "max_retries"):
|
||||
datatype = "INT32" if isinstance(v, int) else "BYTES"
|
||||
datatype = "FP32" if isinstance(v, float) else datatype
|
||||
data_for_triton["inputs"].append(
|
||||
{"name": k, "shape": [1], "datatype": datatype, "data": [v]}
|
||||
)
|
||||
|
||||
if "max_tokens" not in optional_params:
|
||||
data_for_triton["inputs"].append(
|
||||
{
|
||||
"name": "max_tokens",
|
||||
"shape": [1],
|
||||
"datatype": "INT32",
|
||||
"data": [20],
|
||||
}
|
||||
)
|
||||
return data_for_triton
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: Dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
litellm_params: Dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise TritonError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
_triton_response_data = raw_response_json["outputs"][0]["data"]
|
||||
triton_response_data: Optional[str] = None
|
||||
if isinstance(_triton_response_data, list):
|
||||
triton_response_data = "".join(_triton_response_data)
|
||||
else:
|
||||
triton_response_data = _triton_response_data
|
||||
|
||||
model_response.choices = [
|
||||
Choices(
|
||||
index=0,
|
||||
message=Message(content=triton_response_data),
|
||||
)
|
||||
]
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
class TritonResponseIterator(BaseModelResponseIterator):
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
try:
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
provider_specific_fields = None
|
||||
index = int(chunk.get("index", 0))
|
||||
|
||||
# set values
|
||||
text = chunk.get("text_output", "")
|
||||
finish_reason = chunk.get("stop_reason", "")
|
||||
is_finished = chunk.get("is_finished", False)
|
||||
|
||||
return GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
index=index,
|
||||
provider_specific_fields=provider_specific_fields,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||
|
|
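Taken together, `TritonConfig` dispatches on the `api_base` suffix. A standalone sketch of that routing decision, mirroring `_get_triton_llm_type` above rather than the litellm internals themselves:

```python
def pick_triton_config(api_base: str) -> str:
    """Mirror of the api_base routing in TritonConfig._get_triton_llm_type (sketch)."""
    if api_base.endswith("/generate"):
        return "TritonGenerateConfig"  # trtllm-style text_input/parameters payload
    elif api_base.endswith("/infer"):
        return "TritonInferConfig"     # KServe-style inputs[] payload
    raise ValueError(f"Invalid Triton API base: {api_base}")


print(pick_triton_config("http://localhost:8000/generate"))  # TritonGenerateConfig
print(pick_triton_config("http://localhost:8000/infer"))     # TritonInferConfig
```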
litellm/llms/triton/embedding/transformation.py (new file, 121 lines)
|
@ -0,0 +1,121 @@
|
|||
from typing import List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import AllMessageValues, BaseLLMException
|
||||
from litellm.llms.base_llm.embedding.transformation import (
|
||||
BaseEmbeddingConfig,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from ..common_utils import TritonError
|
||||
|
||||
|
||||
class TritonEmbeddingConfig(BaseEmbeddingConfig):
|
||||
"""
|
||||
Transformations for the Triton /embeddings endpoint
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
return []
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI params to Triton Embedding params
|
||||
"""
|
||||
return optional_params
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
def transform_embedding_request(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[str, List[str], List[float], List[List[float]]],
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
return {
|
||||
"inputs": [
|
||||
{
|
||||
"name": "input_text",
|
||||
"shape": [len(input)],
|
||||
"datatype": "BYTES",
|
||||
"data": input,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
def transform_embedding_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: EmbeddingResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str] = None,
|
||||
request_data: dict = {},
|
||||
optional_params: dict = {},
|
||||
litellm_params: dict = {},
|
||||
) -> EmbeddingResponse:
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise TritonError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
_embedding_output = []
|
||||
|
||||
_outputs = raw_response_json["outputs"]
|
||||
for output in _outputs:
|
||||
_shape = output["shape"]
|
||||
_data = output["data"]
|
||||
_split_output_data = self.split_embedding_by_shape(_data, _shape)
|
||||
|
||||
for idx, embedding in enumerate(_split_output_data):
|
||||
_embedding_output.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding,
|
||||
}
|
||||
)
|
||||
|
||||
model_response.model = raw_response_json.get("model_name", "None")
|
||||
model_response.data = _embedding_output
|
||||
return model_response
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return TritonError(
|
||||
message=error_message, status_code=status_code, headers=headers
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def split_embedding_by_shape(
|
||||
data: List[float], shape: List[int]
|
||||
) -> List[List[float]]:
|
||||
if len(shape) != 2:
|
||||
raise ValueError("Shape must be of length 2.")
|
||||
embedding_size = shape[1]
|
||||
return [
|
||||
data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
|
||||
]
|
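A short example of the reshaping helper above, matching the unit test added in this commit:

```python
from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig

# A flat Triton output with shape [2, 3] splits into two 3-dimensional embeddings.
flat = [1, 2, 3, 4, 5, 6]
print(TritonEmbeddingConfig.split_embedding_by_shape(flat, [2, 3]))
# [[1, 2, 3], [4, 5, 6]]
```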
|
@ -128,7 +128,6 @@ from .llms.replicate.chat.handler import completion as replicate_chat_completion
|
|||
from .llms.sagemaker.chat.handler import SagemakerChatHandler
|
||||
from .llms.sagemaker.completion.handler import SagemakerLLM
|
||||
from .llms.together_ai.completion.handler import TogetherAITextCompletion
|
||||
from .llms.triton.completion.handler import TritonChatCompletion
|
||||
from .llms.vertex_ai import vertex_ai_non_gemini
|
||||
from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
|
||||
|
@ -194,7 +193,6 @@ azure_audio_transcriptions = AzureAudioTranscription()
|
|||
huggingface = Huggingface()
|
||||
predibase_chat_completions = PredibaseChatCompletion()
|
||||
codestral_text_completions = CodestralTextCompletion()
|
||||
triton_chat_completions = TritonChatCompletion()
|
||||
bedrock_chat_completion = BedrockLLM()
|
||||
bedrock_converse_chat_completion = BedrockConverseLLM()
|
||||
bedrock_embedding = BedrockEmbedding()
|
||||
|
@ -2711,24 +2709,22 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
|
||||
elif custom_llm_provider == "triton":
|
||||
api_base = litellm.api_base or api_base
|
||||
model_response = triton_chat_completions.completion(
|
||||
api_base=api_base,
|
||||
timeout=timeout, # type: ignore
|
||||
response = base_llm_http_handler.completion(
|
||||
model=model,
|
||||
stream=stream,
|
||||
messages=messages,
|
||||
acompletion=acompletion,
|
||||
api_base=api_base,
|
||||
model_response=model_response,
|
||||
optional_params=optional_params,
|
||||
logging_obj=logging,
|
||||
stream=stream,
|
||||
acompletion=acompletion,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
response = model_response
|
||||
return response
|
||||
|
||||
elif custom_llm_provider == "cloudflare":
|
||||
api_key = (
|
||||
api_key
|
||||
|
@ -3477,9 +3473,10 @@ def embedding( # noqa: PLR0915
|
|||
raise ValueError(
|
||||
"api_base is required for triton. Please pass `api_base`"
|
||||
)
|
||||
response = triton_chat_completions.embedding( # type: ignore
|
||||
response = base_llm_http_handler.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
|
|
|
@ -2249,10 +2249,19 @@ def get_optional_params_embeddings( # noqa: PLR0915
|
|||
message="Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
|
||||
)
|
||||
elif custom_llm_provider == "triton":
|
||||
keys = list(non_default_params.keys())
|
||||
for k in keys:
|
||||
non_default_params.pop(k, None)
|
||||
final_params = {**non_default_params, **kwargs}
|
||||
supported_params = get_supported_openai_params(
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
request_type="embeddings",
|
||||
)
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
optional_params = litellm.TritonEmbeddingConfig().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params={},
|
||||
model=model,
|
||||
drop_params=drop_params if drop_params is not None else False,
|
||||
)
|
||||
final_params = {**optional_params, **kwargs}
|
||||
return final_params
|
||||
elif custom_llm_provider == "databricks":
|
||||
supported_params = get_supported_openai_params(
|
||||
|
@ -2812,6 +2821,17 @@ def get_optional_params( # noqa: PLR0915
|
|||
else False
|
||||
),
|
||||
)
|
||||
elif custom_llm_provider == "triton":
|
||||
supported_params = get_supported_openai_params(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
optional_params = litellm.TritonConfig().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params if drop_params is not None else False,
|
||||
)
|
||||
|
||||
elif custom_llm_provider == "maritalk":
|
||||
## check if unsupported param passed in
|
||||
|
@ -6222,6 +6242,8 @@ class ProviderConfigManager:
|
|||
) -> BaseEmbeddingConfig:
|
||||
if litellm.LlmProviders.VOYAGE == provider:
|
||||
return litellm.VoyageEmbeddingConfig()
|
||||
elif litellm.LlmProviders.TRITON == provider:
|
||||
return litellm.TritonEmbeddingConfig()
|
||||
raise ValueError(f"Provider {provider} does not support embedding config")
|
||||
|
||||
@staticmethod
|
||||
|
|
tests/llm_translation/test_triton.py (new file, 210 lines)
|
@ -0,0 +1,210 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
|
||||
import pytest
|
||||
from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig
|
||||
import litellm
|
||||
|
||||
|
||||
def test_split_embedding_by_shape_passes():
|
||||
try:
|
||||
data = [
|
||||
{
|
||||
"shape": [2, 3],
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
}
|
||||
]
|
||||
split_output_data = TritonEmbeddingConfig.split_embedding_by_shape(
|
||||
data[0]["data"], data[0]["shape"]
|
||||
)
|
||||
assert split_output_data == [[1, 2, 3], [4, 5, 6]]
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occured: {e}")
|
||||
|
||||
|
||||
def test_split_embedding_by_shape_fails_with_shape_value_error():
|
||||
data = [
|
||||
{
|
||||
"shape": [2],
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
}
|
||||
]
|
||||
with pytest.raises(ValueError):
|
||||
TritonEmbeddingConfig.split_embedding_by_shape(
|
||||
data[0]["data"], data[0]["shape"]
|
||||
)
|
||||
|
||||
|
||||
def test_completion_triton_generate_api():
|
||||
try:
|
||||
mock_response = MagicMock()
|
||||
|
||||
def return_val():
|
||||
return {
|
||||
"text_output": "I am an AI assistant",
|
||||
}
|
||||
|
||||
mock_response.json = return_val
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch(
|
||||
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
|
||||
return_value=mock_response,
|
||||
) as mock_post:
|
||||
response = litellm.completion(
|
||||
model="triton/llama-3-8b-instruct",
|
||||
messages=[{"role": "user", "content": "who are u?"}],
|
||||
max_tokens=10,
|
||||
timeout=5,
|
||||
api_base="http://localhost:8000/generate",
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_post.assert_called_once()
|
||||
|
||||
# Get the arguments passed to the post request
|
||||
print("call args", mock_post.call_args)
|
||||
call_kwargs = mock_post.call_args.kwargs # Access kwargs directly
|
||||
|
||||
# Verify URL
|
||||
assert call_kwargs["url"] == "http://localhost:8000/generate"
|
||||
|
||||
# Parse the request data from the JSON string
|
||||
request_data = json.loads(call_kwargs["data"])
|
||||
|
||||
# Verify request data
|
||||
assert request_data["text_input"] == "who are u?"
|
||||
assert request_data["parameters"]["max_tokens"] == 10
|
||||
|
||||
# Verify response
|
||||
assert response.choices[0].message.content == "I am an AI assistant"
|
||||
|
||||
except Exception as e:
|
||||
print("exception", e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_triton_infer_api():
|
||||
litellm.set_verbose = True
|
||||
try:
|
||||
mock_response = MagicMock()
|
||||
|
||||
def return_val():
|
||||
return {
|
||||
"model_name": "basketgpt",
|
||||
"model_version": "2",
|
||||
"outputs": [
|
||||
{
|
||||
"name": "text_output",
|
||||
"datatype": "BYTES",
|
||||
"shape": [1],
|
||||
"data": [
|
||||
"0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027"
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "debug_probs",
|
||||
"datatype": "FP32",
|
||||
"shape": [0],
|
||||
"data": [],
|
||||
},
|
||||
{
|
||||
"name": "debug_tokens",
|
||||
"datatype": "BYTES",
|
||||
"shape": [0],
|
||||
"data": [],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
mock_response.json = return_val
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch(
|
||||
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
|
||||
return_value=mock_response,
|
||||
) as mock_post:
|
||||
response = litellm.completion(
|
||||
model="triton/llama-3-8b-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "0004900005025 0004900005026 0004900005027",
|
||||
}
|
||||
],
|
||||
api_base="http://localhost:8000/infer",
|
||||
)
|
||||
|
||||
print("litellm response", response.model_dump_json(indent=4))
|
||||
|
||||
# Verify the call was made
|
||||
mock_post.assert_called_once()
|
||||
|
||||
# Get the arguments passed to the post request
|
||||
call_kwargs = mock_post.call_args.kwargs
|
||||
|
||||
# Verify URL
|
||||
assert call_kwargs["url"] == "http://localhost:8000/infer"
|
||||
|
||||
# Parse the request data from the JSON string
|
||||
request_data = json.loads(call_kwargs["data"])
|
||||
|
||||
# Verify request matches expected Triton format
|
||||
assert request_data["inputs"][0]["name"] == "text_input"
|
||||
assert request_data["inputs"][0]["shape"] == [1]
|
||||
assert request_data["inputs"][0]["datatype"] == "BYTES"
|
||||
assert request_data["inputs"][0]["data"] == [
|
||||
"0004900005025 0004900005026 0004900005027"
|
||||
]
|
||||
|
||||
assert request_data["inputs"][1]["shape"] == [1]
|
||||
assert request_data["inputs"][1]["datatype"] == "INT32"
|
||||
assert request_data["inputs"][1]["data"] == [20]
|
||||
|
||||
# Verify response format matches expected completion format
|
||||
assert (
|
||||
response.choices[0].message.content
|
||||
== "0004900005024 0004900006774 0004900005024 0004900005027 0004900005026 0004900005025 0004900005027 0004900005024 0004900006774 0004900005027"
|
||||
)
|
||||
assert response.choices[0].finish_reason == "stop"
|
||||
assert response.choices[0].index == 0
|
||||
assert response.object == "chat.completion"
|
||||
|
||||
except Exception as e:
|
||||
print("exception", e)
|
||||
traceback.print_exc()
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_triton_embeddings():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
response = await litellm.aembedding(
|
||||
model="triton/my-triton-model",
|
||||
api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
|
||||
input=["good morning from litellm"],
|
||||
)
|
||||
print(f"response: {response}")
|
||||
|
||||
# stubbed endpoint is setup to return this
|
||||
assert response.data[0]["embedding"] == [0.1, 0.2]
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
|
@ -888,23 +888,6 @@ def test_voyage_embeddings():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_triton_embeddings():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
response = await litellm.aembedding(
|
||||
model="triton/my-triton-model",
|
||||
api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
|
||||
input=["good morning from litellm"],
|
||||
)
|
||||
print(f"response: {response}")
|
||||
|
||||
# stubbed endpoint is setup to return this
|
||||
assert response.data[0]["embedding"] == [0.1, 0.2]
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"input", ["good morning from litellm", ["good morning from litellm"]] #
|
||||
|
|
|
@ -1,56 +0,0 @@
|
|||
import pytest
|
||||
from litellm.llms.triton.completion.handler import TritonChatCompletion
|
||||
|
||||
|
||||
def test_split_embedding_by_shape_passes():
|
||||
try:
|
||||
triton = TritonChatCompletion()
|
||||
data = [
|
||||
{
|
||||
"shape": [2, 3],
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
}
|
||||
]
|
||||
split_output_data = triton.split_embedding_by_shape(
|
||||
data[0]["data"], data[0]["shape"]
|
||||
)
|
||||
assert split_output_data == [[1, 2, 3], [4, 5, 6]]
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occured: {e}")
|
||||
|
||||
|
||||
def test_split_embedding_by_shape_fails_with_shape_value_error():
|
||||
triton = TritonChatCompletion()
|
||||
data = [
|
||||
{
|
||||
"shape": [2],
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
}
|
||||
]
|
||||
with pytest.raises(ValueError):
|
||||
triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
|
||||
|
||||
|
||||
def test_completion_triton():
|
||||
from litellm import completion
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
client = HTTPHandler()
|
||||
with patch.object(client, "post") as mock_post:
|
||||
try:
|
||||
response = completion(
|
||||
model="triton/llama-3-8b-instruct",
|
||||
messages=[{"role": "user", "content": "who are u?"}],
|
||||
max_tokens=10,
|
||||
timeout=5,
|
||||
client=client,
|
||||
api_base="http://localhost:8000/generate",
|
||||
)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_post.assert_called_once()
|
||||
|
||||
print(mock_post.call_args.kwargs)
|