forked from phoenix/litellm-mirror
Merge pull request #3577 from BerriAI/litellm_add_triton_server
[Feat] Add Triton Embeddings to LiteLLM
This commit is contained in:
commit
b09075da53
8 changed files with 265 additions and 0 deletions
95
docs/my-website/docs/providers/triton-inference-server.md
Normal file
95
docs/my-website/docs/providers/triton-inference-server.md
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
# Triton Inference Server
|
||||||
|
|
||||||
|
LiteLLM supports Embedding Models on Triton Inference Servers
|
||||||
|
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
|
||||||
|
### Example Call
|
||||||
|
|
||||||
|
Use the `triton/` prefix to route to triton server
|
||||||
|
```python
|
||||||
|
from litellm import embedding
|
||||||
|
import os
|
||||||
|
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
model="triton/<your-triton-model>",
|
||||||
|
api_base="https://your-triton-api-base/triton/embeddings", # /embeddings endpoint you want litellm to call on your server
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Add models to your config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: my-triton-model
|
||||||
|
litellm_params:
|
||||||
|
model: triton/<your-triton-model>"
|
||||||
|
api_base: https://your-triton-api-base/triton/embeddings
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
2. Start the proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ litellm --config /path/to/config.yaml --detailed_debug
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Send Request to LiteLLM Proxy Server
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# set base_url to your proxy server
|
||||||
|
# set api_key to send to proxy server
|
||||||
|
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
|
||||||
|
|
||||||
|
response = client.embeddings.create(
|
||||||
|
input=["hello from litellm"],
|
||||||
|
model="my-triton-model"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="curl">
|
||||||
|
|
||||||
|
`--header` is optional, only required if you're using litellm proxy with Virtual Keys
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/embeddings' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--data ' {
|
||||||
|
"model": "my-triton-model",
|
||||||
|
"input": ["write a litellm poem"]
|
||||||
|
}'
|
||||||
|
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
|
@ -134,6 +134,7 @@ const sidebars = {
|
||||||
"providers/huggingface",
|
"providers/huggingface",
|
||||||
"providers/watsonx",
|
"providers/watsonx",
|
||||||
"providers/predibase",
|
"providers/predibase",
|
||||||
|
"providers/triton-inference-server",
|
||||||
"providers/ollama",
|
"providers/ollama",
|
||||||
"providers/perplexity",
|
"providers/perplexity",
|
||||||
"providers/groq",
|
"providers/groq",
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
### Hide pydantic namespace conflict warnings globally ###
|
### Hide pydantic namespace conflict warnings globally ###
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
|
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
|
||||||
### INIT VARIABLES ###
|
### INIT VARIABLES ###
|
||||||
import threading, requests, os
|
import threading, requests, os
|
||||||
|
@ -537,6 +538,7 @@ provider_list: List = [
|
||||||
"xinference",
|
"xinference",
|
||||||
"fireworks_ai",
|
"fireworks_ai",
|
||||||
"watsonx",
|
"watsonx",
|
||||||
|
"triton",
|
||||||
"predibase",
|
"predibase",
|
||||||
"custom", # custom apis
|
"custom", # custom apis
|
||||||
]
|
]
|
||||||
|
|
119
litellm/llms/triton.py
Normal file
119
litellm/llms/triton.py
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
import os, types
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
import requests, copy # type: ignore
|
||||||
|
import time
|
||||||
|
from typing import Callable, Optional, List
|
||||||
|
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||||
|
import litellm
|
||||||
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
from .base import BaseLLM
|
||||||
|
import httpx # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class TritonError(Exception):
|
||||||
|
def __init__(self, status_code, message):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.message = message
|
||||||
|
self.request = httpx.Request(
|
||||||
|
method="POST",
|
||||||
|
url="https://api.anthropic.com/v1/messages", # using anthropic api base since httpx requires a url
|
||||||
|
)
|
||||||
|
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||||
|
super().__init__(
|
||||||
|
self.message
|
||||||
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
|
class TritonChatCompletion(BaseLLM):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def aembedding(
|
||||||
|
self,
|
||||||
|
data: dict,
|
||||||
|
model_response: litellm.utils.EmbeddingResponse,
|
||||||
|
api_base: str,
|
||||||
|
logging_obj=None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
async_handler = AsyncHTTPHandler(
|
||||||
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await async_handler.post(url=api_base, data=json.dumps(data))
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise TritonError(status_code=response.status_code, message=response.text)
|
||||||
|
|
||||||
|
_text_response = response.text
|
||||||
|
|
||||||
|
logging_obj.post_call(original_response=_text_response)
|
||||||
|
|
||||||
|
_json_response = response.json()
|
||||||
|
|
||||||
|
_outputs = _json_response["outputs"]
|
||||||
|
_output_data = _outputs[0]["data"]
|
||||||
|
_embedding_output = {
|
||||||
|
"object": "embedding",
|
||||||
|
"index": 0,
|
||||||
|
"embedding": _output_data,
|
||||||
|
}
|
||||||
|
|
||||||
|
model_response.model = _json_response.get("model_name", "None")
|
||||||
|
model_response.data = [_embedding_output]
|
||||||
|
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def embedding(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
input: list,
|
||||||
|
timeout: float,
|
||||||
|
api_base: str,
|
||||||
|
model_response: litellm.utils.EmbeddingResponse,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
logging_obj=None,
|
||||||
|
optional_params=None,
|
||||||
|
client=None,
|
||||||
|
aembedding=None,
|
||||||
|
):
|
||||||
|
data_for_triton = {
|
||||||
|
"inputs": [
|
||||||
|
{
|
||||||
|
"name": "input_text",
|
||||||
|
"shape": [1],
|
||||||
|
"datatype": "BYTES",
|
||||||
|
"data": input,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
|
||||||
|
curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
|
||||||
|
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input="",
|
||||||
|
api_key=None,
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": optional_params,
|
||||||
|
"request_str": curl_string,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if aembedding == True:
|
||||||
|
response = self.aembedding(
|
||||||
|
data=data_for_triton,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"Only async embedding supported for triton, please use litellm.aembedding() for now"
|
||||||
|
)
|
|
@ -47,6 +47,7 @@ from .llms import (
|
||||||
ai21,
|
ai21,
|
||||||
sagemaker,
|
sagemaker,
|
||||||
bedrock,
|
bedrock,
|
||||||
|
triton,
|
||||||
huggingface_restapi,
|
huggingface_restapi,
|
||||||
replicate,
|
replicate,
|
||||||
aleph_alpha,
|
aleph_alpha,
|
||||||
|
@ -75,6 +76,7 @@ from .llms.anthropic import AnthropicChatCompletion
|
||||||
from .llms.anthropic_text import AnthropicTextCompletion
|
from .llms.anthropic_text import AnthropicTextCompletion
|
||||||
from .llms.huggingface_restapi import Huggingface
|
from .llms.huggingface_restapi import Huggingface
|
||||||
from .llms.predibase import PredibaseChatCompletion
|
from .llms.predibase import PredibaseChatCompletion
|
||||||
|
from .llms.triton import TritonChatCompletion
|
||||||
from .llms.prompt_templates.factory import (
|
from .llms.prompt_templates.factory import (
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
custom_prompt,
|
custom_prompt,
|
||||||
|
@ -112,6 +114,7 @@ azure_chat_completions = AzureChatCompletion()
|
||||||
azure_text_completions = AzureTextCompletion()
|
azure_text_completions = AzureTextCompletion()
|
||||||
huggingface = Huggingface()
|
huggingface = Huggingface()
|
||||||
predibase_chat_completions = PredibaseChatCompletion()
|
predibase_chat_completions = PredibaseChatCompletion()
|
||||||
|
triton_chat_completions = TritonChatCompletion()
|
||||||
####### COMPLETION ENDPOINTS ################
|
####### COMPLETION ENDPOINTS ################
|
||||||
|
|
||||||
|
|
||||||
|
@ -2622,6 +2625,7 @@ async def aembedding(*args, **kwargs):
|
||||||
or custom_llm_provider == "voyage"
|
or custom_llm_provider == "voyage"
|
||||||
or custom_llm_provider == "mistral"
|
or custom_llm_provider == "mistral"
|
||||||
or custom_llm_provider == "custom_openai"
|
or custom_llm_provider == "custom_openai"
|
||||||
|
or custom_llm_provider == "triton"
|
||||||
or custom_llm_provider == "anyscale"
|
or custom_llm_provider == "anyscale"
|
||||||
or custom_llm_provider == "openrouter"
|
or custom_llm_provider == "openrouter"
|
||||||
or custom_llm_provider == "deepinfra"
|
or custom_llm_provider == "deepinfra"
|
||||||
|
@ -2955,6 +2959,23 @@ def embedding(
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
model_response=EmbeddingResponse(),
|
model_response=EmbeddingResponse(),
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "triton":
|
||||||
|
if api_base is None:
|
||||||
|
raise ValueError(
|
||||||
|
"api_base is required for triton. Please pass `api_base`"
|
||||||
|
)
|
||||||
|
response = triton_chat_completions.embedding(
|
||||||
|
model=model,
|
||||||
|
input=input,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
logging_obj=logging,
|
||||||
|
timeout=timeout,
|
||||||
|
model_response=EmbeddingResponse(),
|
||||||
|
optional_params=optional_params,
|
||||||
|
client=client,
|
||||||
|
aembedding=aembedding,
|
||||||
|
)
|
||||||
elif custom_llm_provider == "vertex_ai":
|
elif custom_llm_provider == "vertex_ai":
|
||||||
vertex_ai_project = (
|
vertex_ai_project = (
|
||||||
optional_params.pop("vertex_project", None)
|
optional_params.pop("vertex_project", None)
|
||||||
|
|
|
@ -8,6 +8,10 @@ model_list:
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: openai/*
|
model: openai/*
|
||||||
api_key: os.environ/OPENAI_API_KEY
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
- model_name: my-triton-model
|
||||||
|
litellm_params:
|
||||||
|
model: triton/any"
|
||||||
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
store_model_in_db: true
|
store_model_in_db: true
|
||||||
|
|
|
@ -516,6 +516,23 @@ def test_voyage_embeddings():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_triton_embeddings():
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
model="triton/my-triton-model",
|
||||||
|
api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
)
|
||||||
|
print(f"response: {response}")
|
||||||
|
|
||||||
|
# stubbed endpoint is setup to return this
|
||||||
|
assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
# test_voyage_embeddings()
|
# test_voyage_embeddings()
|
||||||
# def test_xinference_embeddings():
|
# def test_xinference_embeddings():
|
||||||
# try:
|
# try:
|
||||||
|
|
|
@ -4814,6 +4814,12 @@ def get_optional_params_embeddings(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
message=f"Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
|
message=f"Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
|
||||||
)
|
)
|
||||||
|
if custom_llm_provider == "triton":
|
||||||
|
keys = list(non_default_params.keys())
|
||||||
|
for k in keys:
|
||||||
|
non_default_params.pop(k, None)
|
||||||
|
final_params = {**non_default_params, **kwargs}
|
||||||
|
return final_params
|
||||||
if custom_llm_provider == "vertex_ai":
|
if custom_llm_provider == "vertex_ai":
|
||||||
if len(non_default_params.keys()) > 0:
|
if len(non_default_params.keys()) > 0:
|
||||||
if litellm.drop_params is True: # drop the unsupported non-default values
|
if litellm.drop_params is True: # drop the unsupported non-default values
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue