forked from phoenix/litellm-mirror

Merge pull request #3577 from BerriAI/litellm_add_triton_server

[Feat] Add Triton Embeddings to LiteLLM

Commit b09075da53: 8 changed files with 265 additions and 0 deletions
docs/my-website/docs/providers/triton-inference-server.md (new file, 95 lines)

@@ -0,0 +1,95 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Triton Inference Server

LiteLLM supports embedding models hosted on a Triton Inference Server.

## Usage

<Tabs>
<TabItem value="sdk" label="SDK">

### Example Call

Use the `triton/` prefix to route to your Triton server.

```python
import litellm

response = await litellm.aembedding(
    model="triton/<your-triton-model>",
    api_base="https://your-triton-api-base/triton/embeddings",  # /embeddings endpoint you want litellm to call on your server
    input=["good morning from litellm"],
)
```
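
The call returns an OpenAI-style `EmbeddingResponse`; per the handler added in this PR, `model` is filled from the `model_name` Triton reports and the vector lands in `data[0]["embedding"]`. A minimal sketch of reading it (values are placeholders):

```python
# model_name reported back by the Triton server
print(response.model)

# each entry in response.data looks like {"object": "embedding", "index": 0, "embedding": [...]}
print(response.data[0]["embedding"][:5])  # first few floats of the vector
```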

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: my-triton-model
    litellm_params:
      model: triton/<your-triton-model>
      api_base: https://your-triton-api-base/triton/embeddings
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```

3. Send Request to LiteLLM Proxy Server

<Tabs>

<TabItem value="openai" label="OpenAI Python v1.0.0+">

```python
import openai
from openai import OpenAI

# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")

response = client.embeddings.create(
    input=["hello from litellm"],
    model="my-triton-model"
)

print(response)
```

</TabItem>

<TabItem value="curl" label="curl">

`--header` is optional; it is only required if you're using the LiteLLM proxy with Virtual Keys.

```shell
curl --location 'http://0.0.0.0:4000/embeddings' \
  --header 'Content-Type: application/json' \
  --header 'Authorization: Bearer sk-1234' \
  --data '{
    "model": "my-triton-model",
    "input": ["write a litellm poem"]
  }'
```

</TabItem>

</Tabs>

</TabItem>

</Tabs>

@@ -134,6 +134,7 @@ const sidebars = {
        "providers/huggingface",
        "providers/watsonx",
        "providers/predibase",
        "providers/triton-inference-server",
        "providers/ollama",
        "providers/perplexity",
        "providers/groq",
@@ -1,5 +1,6 @@
### Hide pydantic namespace conflict warnings globally ###
import warnings

warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
### INIT VARIABLES ###
import threading, requests, os

@@ -537,6 +538,7 @@ provider_list: List = [
    "xinference",
    "fireworks_ai",
    "watsonx",
    "triton",
    "predibase",
    "custom",  # custom apis
]
litellm/llms/triton.py (new file, 119 lines)

@@ -0,0 +1,119 @@
import os, types
import json
from enum import Enum
import requests, copy  # type: ignore
import time
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx  # type: ignore


class TritonError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST",
            url="https://api.anthropic.com/v1/messages",  # using anthropic api base since httpx requires a url
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class TritonChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    async def aembedding(
        self,
        data: dict,
        model_response: litellm.utils.EmbeddingResponse,
        api_base: str,
        logging_obj=None,
        api_key: Optional[str] = None,
    ):
        async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )

        response = await async_handler.post(url=api_base, data=json.dumps(data))

        if response.status_code != 200:
            raise TritonError(status_code=response.status_code, message=response.text)

        _text_response = response.text

        logging_obj.post_call(original_response=_text_response)

        _json_response = response.json()

        _outputs = _json_response["outputs"]
        _output_data = _outputs[0]["data"]
        _embedding_output = {
            "object": "embedding",
            "index": 0,
            "embedding": _output_data,
        }

        model_response.model = _json_response.get("model_name", "None")
        model_response.data = [_embedding_output]

        return model_response

    def embedding(
        self,
        model: str,
        input: list,
        timeout: float,
        api_base: str,
        model_response: litellm.utils.EmbeddingResponse,
        api_key: Optional[str] = None,
        logging_obj=None,
        optional_params=None,
        client=None,
        aembedding=None,
    ):
        data_for_triton = {
            "inputs": [
                {
                    "name": "input_text",
                    "shape": [1],
                    "datatype": "BYTES",
                    "data": input,
                }
            ]
        }

        ## LOGGING
        curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"

        logging_obj.pre_call(
            input="",
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": curl_string,
            },
        )

        if aembedding == True:
            response = self.aembedding(
                data=data_for_triton,
                model_response=model_response,
                logging_obj=logging_obj,
                api_base=api_base,
                api_key=api_key,
            )
            return response
        else:
            raise Exception(
                "Only async embedding supported for triton, please use litellm.aembedding() for now"
            )
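
For orientation, here is a sketch of the request body `TritonChatCompletion.embedding()` builds and the minimal response shape `aembedding()` then parses; only `model_name` and `outputs[0]["data"]` are read, and every concrete value below is illustrative:

```python
# request body posted to api_base (KServe v2-style inference request)
request_body = {
    "inputs": [
        {
            "name": "input_text",
            "shape": [1],
            "datatype": "BYTES",
            "data": ["good morning from litellm"],
        }
    ]
}

# minimal response the handler expects back from the server
triton_response = {
    "model_name": "my-triton-model",  # copied into EmbeddingResponse.model
    "outputs": [
        {"name": "embedding", "shape": [3], "datatype": "FP32", "data": [0.1, 0.2, 0.3]},
    ],
}
```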

@@ -47,6 +47,7 @@ from .llms import (
    ai21,
    sagemaker,
    bedrock,
    triton,
    huggingface_restapi,
    replicate,
    aleph_alpha,
@@ -75,6 +76,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
    prompt_factory,
    custom_prompt,
@@ -112,6 +114,7 @@ azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
####### COMPLETION ENDPOINTS ################

@@ -2622,6 +2625,7 @@ async def aembedding(*args, **kwargs):
        or custom_llm_provider == "voyage"
        or custom_llm_provider == "mistral"
        or custom_llm_provider == "custom_openai"
        or custom_llm_provider == "triton"
        or custom_llm_provider == "anyscale"
        or custom_llm_provider == "openrouter"
        or custom_llm_provider == "deepinfra"

@@ -2955,6 +2959,23 @@ def embedding(
            optional_params=optional_params,
            model_response=EmbeddingResponse(),
        )
    elif custom_llm_provider == "triton":
        if api_base is None:
            raise ValueError(
                "api_base is required for triton. Please pass `api_base`"
            )
        response = triton_chat_completions.embedding(
            model=model,
            input=input,
            api_base=api_base,
            api_key=api_key,
            logging_obj=logging,
            timeout=timeout,
            model_response=EmbeddingResponse(),
            optional_params=optional_params,
            client=client,
            aembedding=aembedding,
        )
    elif custom_llm_provider == "vertex_ai":
        vertex_ai_project = (
            optional_params.pop("vertex_project", None)
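
A minimal sketch of what this routing means for callers; exception types may be re-wrapped by litellm's error handling, and the model/endpoint names below are placeholders:

```python
import litellm

# api_base is required for the triton route
try:
    litellm.embedding(model="triton/my-model", input=["hi"])
except Exception as e:
    print(e)  # api_base is required for triton. Please pass `api_base`

# the sync path is not implemented yet; only litellm.aembedding() works for now
try:
    litellm.embedding(
        model="triton/my-model",
        api_base="https://your-triton-host/triton/embeddings",
        input=["hi"],
    )
except Exception as e:
    print(e)  # Only async embedding supported for triton, please use litellm.aembedding() for now
```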

@@ -8,6 +8,10 @@ model_list:
    litellm_params:
      model: openai/*
      api_key: os.environ/OPENAI_API_KEY
  - model_name: my-triton-model
    litellm_params:
      model: triton/any
      api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings

general_settings:
  store_model_in_db: true

@@ -516,6 +516,23 @@ def test_voyage_embeddings():
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio
async def test_triton_embeddings():
    try:
        litellm.set_verbose = True
        response = await litellm.aembedding(
            model="triton/my-triton-model",
            api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
            input=["good morning from litellm"],
        )
        print(f"response: {response}")

        # stubbed endpoint is setup to return this
        assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_voyage_embeddings()
# def test_xinference_embeddings():
#     try:
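
Assuming the test above lives in litellm's existing embedding test module (the file path is not shown in this diff), something along these lines should run only the new test:

```python
import pytest

# run just the new triton embedding test, printing test stdout
pytest.main(["-k", "test_triton_embeddings", "-s", "-v"])
```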

@@ -4814,6 +4814,12 @@ def get_optional_params_embeddings(
            status_code=500,
            message=f"Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
        )
    if custom_llm_provider == "triton":
        keys = list(non_default_params.keys())
        for k in keys:
            non_default_params.pop(k, None)
        final_params = {**non_default_params, **kwargs}
        return final_params
    if custom_llm_provider == "vertex_ai":
        if len(non_default_params.keys()) > 0:
            if litellm.drop_params is True:  # drop the unsupported non-default values
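
In effect, every non-default embedding param is silently dropped for triton before the request is built; a minimal sketch of that branch in isolation (the parameter names are only examples):

```python
# kwargs a caller might have passed alongside the input
non_default_params = {"encoding_format": "float", "dimensions": 512}
kwargs = {}

# triton branch: drop every non-default param, keep only the remaining kwargs
for k in list(non_default_params.keys()):
    non_default_params.pop(k, None)
final_params = {**non_default_params, **kwargs}

print(final_params)  # {}
```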