Merge pull request #3577 from BerriAI/litellm_add_triton_server

[Feat] Add Triton Embeddings to LiteLLM
This commit is contained in:
Ishaan Jaff 2024-05-10 19:20:23 -07:00 committed by GitHub
commit b09075da53
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 265 additions and 0 deletions

View file

@ -0,0 +1,95 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Triton Inference Server
LiteLLM supports Embedding Models on Triton Inference Servers
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
### Example Call
Use the `triton/` prefix to route to triton server
```python
from litellm import embedding
import os
response = await litellm.aembedding(
model="triton/<your-triton-model>",
api_base="https://your-triton-api-base/triton/embeddings", # /embeddings endpoint you want litellm to call on your server
input=["good morning from litellm"],
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: my-triton-model
litellm_params:
model: triton/<your-triton-model>"
api_base: https://your-triton-api-base/triton/embeddings
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
from openai import OpenAI
# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
response = client.embeddings.create(
input=["hello from litellm"],
model="my-triton-model"
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
`--header` is optional, only required if you're using litellm proxy with Virtual Keys
```shell
curl --location 'http://0.0.0.0:4000/embeddings' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data ' {
"model": "my-triton-model",
"input": ["write a litellm poem"]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>

View file

@ -134,6 +134,7 @@ const sidebars = {
"providers/huggingface",
"providers/watsonx",
"providers/predibase",
"providers/triton-inference-server",
"providers/ollama",
"providers/perplexity",
"providers/groq",

View file

@ -1,5 +1,6 @@
### Hide pydantic namespace conflict warnings globally ###
import warnings
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
### INIT VARIABLES ###
import threading, requests, os
@ -537,6 +538,7 @@ provider_list: List = [
"xinference",
"fireworks_ai",
"watsonx",
"triton",
"predibase",
"custom", # custom apis
]

119
litellm/llms/triton.py Normal file
View file

@ -0,0 +1,119 @@
import os, types
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
class TritonError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST",
url="https://api.anthropic.com/v1/messages", # using anthropic api base since httpx requires a url
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class TritonChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
async def aembedding(
self,
data: dict,
model_response: litellm.utils.EmbeddingResponse,
api_base: str,
logging_obj=None,
api_key: Optional[str] = None,
):
async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await async_handler.post(url=api_base, data=json.dumps(data))
if response.status_code != 200:
raise TritonError(status_code=response.status_code, message=response.text)
_text_response = response.text
logging_obj.post_call(original_response=_text_response)
_json_response = response.json()
_outputs = _json_response["outputs"]
_output_data = _outputs[0]["data"]
_embedding_output = {
"object": "embedding",
"index": 0,
"embedding": _output_data,
}
model_response.model = _json_response.get("model_name", "None")
model_response.data = [_embedding_output]
return model_response
def embedding(
self,
model: str,
input: list,
timeout: float,
api_base: str,
model_response: litellm.utils.EmbeddingResponse,
api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None,
aembedding=None,
):
data_for_triton = {
"inputs": [
{
"name": "input_text",
"shape": [1],
"datatype": "BYTES",
"data": input,
}
]
}
## LOGGING
curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
logging_obj.pre_call(
input="",
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": curl_string,
},
)
if aembedding == True:
response = self.aembedding(
data=data_for_triton,
model_response=model_response,
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
)
return response
else:
raise Exception(
"Only async embedding supported for triton, please use litellm.aembedding() for now"
)

View file

@ -47,6 +47,7 @@ from .llms import (
ai21,
sagemaker,
bedrock,
triton,
huggingface_restapi,
replicate,
aleph_alpha,
@ -75,6 +76,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
custom_prompt,
@ -112,6 +114,7 @@ azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
####### COMPLETION ENDPOINTS ################
@ -2622,6 +2625,7 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "voyage"
or custom_llm_provider == "mistral"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "triton"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "openrouter"
or custom_llm_provider == "deepinfra"
@ -2955,6 +2959,23 @@ def embedding(
optional_params=optional_params,
model_response=EmbeddingResponse(),
)
elif custom_llm_provider == "triton":
if api_base is None:
raise ValueError(
"api_base is required for triton. Please pass `api_base`"
)
response = triton_chat_completions.embedding(
model=model,
input=input,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (
optional_params.pop("vertex_project", None)

View file

@ -8,6 +8,10 @@ model_list:
litellm_params:
model: openai/*
api_key: os.environ/OPENAI_API_KEY
- model_name: my-triton-model
litellm_params:
model: triton/any"
api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
general_settings:
store_model_in_db: true

View file

@ -516,6 +516,23 @@ def test_voyage_embeddings():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_triton_embeddings():
try:
litellm.set_verbose = True
response = await litellm.aembedding(
model="triton/my-triton-model",
api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
input=["good morning from litellm"],
)
print(f"response: {response}")
# stubbed endpoint is setup to return this
assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_voyage_embeddings()
# def test_xinference_embeddings():
# try:

View file

@ -4814,6 +4814,12 @@ def get_optional_params_embeddings(
status_code=500,
message=f"Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
)
if custom_llm_provider == "triton":
keys = list(non_default_params.keys())
for k in keys:
non_default_params.pop(k, None)
final_params = {**non_default_params, **kwargs}
return final_params
if custom_llm_provider == "vertex_ai":
if len(non_default_params.keys()) > 0:
if litellm.drop_params is True: # drop the unsupported non-default values