diff --git a/docs/my-website/docs/providers/triton-inference-server.md b/docs/my-website/docs/providers/triton-inference-server.md
new file mode 100644
index 000000000..aacc46a39
--- /dev/null
+++ b/docs/my-website/docs/providers/triton-inference-server.md
@@ -0,0 +1,95 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# Triton Inference Server
+
+LiteLLM supports embedding models hosted on Triton Inference Server.
+
+## Usage
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+### Example Call
+
+Use the `triton/` prefix to route to your Triton Inference Server.
+
+```python
+import litellm
+
+response = await litellm.aembedding(
+    model="triton/",
+    api_base="https://your-triton-api-base/triton/embeddings",  # /embeddings endpoint you want litellm to call on your server
+    input=["good morning from litellm"],
+)
+```
+
+</TabItem>
+<TabItem value="proxy" label="LiteLLM Proxy Server">
+
+1. Add models to your config.yaml
+
+   ```yaml
+   model_list:
+     - model_name: my-triton-model
+       litellm_params:
+         model: triton/
+         api_base: https://your-triton-api-base/triton/embeddings
+   ```
+
+2. Start the proxy
+
+   ```bash
+   $ litellm --config /path/to/config.yaml --detailed_debug
+   ```
+
+3. Send Request to LiteLLM Proxy Server
+
+   <Tabs>
+
+   <TabItem value="openai" label="OpenAI Python v1.0.0+">
+
+   ```python
+   from openai import OpenAI
+
+   # set base_url to your proxy server
+   # set api_key to send to proxy server
+   client = OpenAI(api_key="", base_url="http://0.0.0.0:4000")
+
+   response = client.embeddings.create(
+       input=["hello from litellm"],
+       model="my-triton-model"
+   )
+
+   print(response)
+   ```
+
+   </TabItem>
+
+   <TabItem value="curl" label="Curl Request">
+
+   `--header` is optional, only required if you're using litellm proxy with Virtual Keys
+
+   ```shell
+   curl --location 'http://0.0.0.0:4000/embeddings' \
+   --header 'Content-Type: application/json' \
+   --header 'Authorization: Bearer sk-1234' \
+   --data ' {
+   "model": "my-triton-model",
+   "input": ["write a litellm poem"]
+   }'
+   ```
+
+   </TabItem>
+
+   </Tabs>
+
+</TabItem>
+
+</Tabs>
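The SDK snippet in the docs page above uses a top-level `await`, which only works inside an async context (e.g. a notebook). A minimal sketch of a standalone script is shown below, assuming a reachable Triton `/embeddings` endpoint; the model suffix and api_base URL are placeholders, not values required by the implementation.

```python
import asyncio

import litellm


async def main():
    # The "triton/" prefix tells LiteLLM which provider to route to; api_base
    # points at your Triton server's embeddings endpoint (placeholder URL here).
    response = await litellm.aembedding(
        model="triton/my-triton-model",
        api_base="https://your-triton-api-base/triton/embeddings",
        input=["good morning from litellm"],
    )
    # response.data is a list of {"object", "index", "embedding"} dicts
    print(response.data[0]["embedding"])


if __name__ == "__main__":
    asyncio.run(main())
```
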
method="POST", + url="https://api.anthropic.com/v1/messages", # using anthropic api base since httpx requires a url + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +class TritonChatCompletion(BaseLLM): + def __init__(self) -> None: + super().__init__() + + async def aembedding( + self, + data: dict, + model_response: litellm.utils.EmbeddingResponse, + api_base: str, + logging_obj=None, + api_key: Optional[str] = None, + ): + + async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=600.0, connect=5.0) + ) + + response = await async_handler.post(url=api_base, data=json.dumps(data)) + + if response.status_code != 200: + raise TritonError(status_code=response.status_code, message=response.text) + + _text_response = response.text + + logging_obj.post_call(original_response=_text_response) + + _json_response = response.json() + + _outputs = _json_response["outputs"] + _output_data = _outputs[0]["data"] + _embedding_output = { + "object": "embedding", + "index": 0, + "embedding": _output_data, + } + + model_response.model = _json_response.get("model_name", "None") + model_response.data = [_embedding_output] + + return model_response + + def embedding( + self, + model: str, + input: list, + timeout: float, + api_base: str, + model_response: litellm.utils.EmbeddingResponse, + api_key: Optional[str] = None, + logging_obj=None, + optional_params=None, + client=None, + aembedding=None, + ): + data_for_triton = { + "inputs": [ + { + "name": "input_text", + "shape": [1], + "datatype": "BYTES", + "data": input, + } + ] + } + + ## LOGGING + + curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'" + + logging_obj.pre_call( + input="", + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": curl_string, + }, + ) + + if aembedding == True: + response = self.aembedding( + data=data_for_triton, + model_response=model_response, + logging_obj=logging_obj, + api_base=api_base, + api_key=api_key, + ) + return response + else: + raise Exception( + "Only async embedding supported for triton, please use litellm.aembedding() for now" + ) diff --git a/litellm/main.py b/litellm/main.py index 72f5b1dc6..3f8c659f9 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -47,6 +47,7 @@ from .llms import ( ai21, sagemaker, bedrock, + triton, huggingface_restapi, replicate, aleph_alpha, @@ -75,6 +76,7 @@ from .llms.anthropic import AnthropicChatCompletion from .llms.anthropic_text import AnthropicTextCompletion from .llms.huggingface_restapi import Huggingface from .llms.predibase import PredibaseChatCompletion +from .llms.triton import TritonChatCompletion from .llms.prompt_templates.factory import ( prompt_factory, custom_prompt, @@ -112,6 +114,7 @@ azure_chat_completions = AzureChatCompletion() azure_text_completions = AzureTextCompletion() huggingface = Huggingface() predibase_chat_completions = PredibaseChatCompletion() +triton_chat_completions = TritonChatCompletion() ####### COMPLETION ENDPOINTS ################ @@ -2622,6 +2625,7 @@ async def aembedding(*args, **kwargs): or custom_llm_provider == "voyage" or custom_llm_provider == "mistral" or custom_llm_provider == "custom_openai" + or custom_llm_provider == "triton" or custom_llm_provider == "anyscale" or custom_llm_provider == "openrouter" or custom_llm_provider == "deepinfra" @@ -2955,6 +2959,23 @@ def embedding( 
diff --git a/litellm/main.py b/litellm/main.py
index 72f5b1dc6..3f8c659f9 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -47,6 +47,7 @@ from .llms import (
     ai21,
     sagemaker,
     bedrock,
+    triton,
     huggingface_restapi,
     replicate,
     aleph_alpha,
@@ -75,6 +76,7 @@ from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.predibase import PredibaseChatCompletion
+from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,
     custom_prompt,
@@ -112,6 +114,7 @@ azure_chat_completions = AzureChatCompletion()
 azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
+triton_chat_completions = TritonChatCompletion()
 
 ####### COMPLETION ENDPOINTS ################
 
@@ -2622,6 +2625,7 @@ async def aembedding(*args, **kwargs):
         or custom_llm_provider == "voyage"
         or custom_llm_provider == "mistral"
         or custom_llm_provider == "custom_openai"
+        or custom_llm_provider == "triton"
         or custom_llm_provider == "anyscale"
         or custom_llm_provider == "openrouter"
         or custom_llm_provider == "deepinfra"
@@ -2955,6 +2959,23 @@ def embedding(
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
         )
+    elif custom_llm_provider == "triton":
+        if api_base is None:
+            raise ValueError(
+                "api_base is required for triton. Please pass `api_base`"
+            )
+        response = triton_chat_completions.embedding(
+            model=model,
+            input=input,
+            api_base=api_base,
+            api_key=api_key,
+            logging_obj=logging,
+            timeout=timeout,
+            model_response=EmbeddingResponse(),
+            optional_params=optional_params,
+            client=client,
+            aembedding=aembedding,
+        )
     elif custom_llm_provider == "vertex_ai":
         vertex_ai_project = (
             optional_params.pop("vertex_project", None)
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 0378efced..b1cbf2e81 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -8,6 +8,10 @@ model_list:
     litellm_params:
       model: openai/*
       api_key: os.environ/OPENAI_API_KEY
+  - model_name: my-triton-model
+    litellm_params:
+      model: triton/any
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
 
 general_settings:
   store_model_in_db: true
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index 69ddb39ff..8da847b64 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -516,6 +516,23 @@ def test_voyage_embeddings():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.asyncio
+async def test_triton_embeddings():
+    try:
+        litellm.set_verbose = True
+        response = await litellm.aembedding(
+            model="triton/my-triton-model",
+            api_base="https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings",
+            input=["good morning from litellm"],
+        )
+        print(f"response: {response}")
+
+        # stubbed endpoint is setup to return this
+        assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # test_voyage_embeddings()
 # def test_xinference_embeddings():
 #     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 09a7851dc..9218f92a3 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4814,6 +4814,12 @@ def get_optional_params_embeddings(
                 status_code=500,
                 message=f"Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
             )
+    if custom_llm_provider == "triton":
+        keys = list(non_default_params.keys())
+        for k in keys:
+            non_default_params.pop(k, None)
+        final_params = {**non_default_params, **kwargs}
+        return final_params
     if custom_llm_provider == "vertex_ai":
         if len(non_default_params.keys()) > 0:
             if litellm.drop_params is True:  # drop the unsupported non-default values
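The `get_optional_params_embeddings` branch above drops every non-default embedding param for triton instead of forwarding it to the server. A minimal sketch of that behavior in isolation, with hypothetical caller params, is shown below.

```python
# Mirrors the triton branch added to get_optional_params_embeddings(): every
# non-default param is popped, and only explicit kwargs survive.
non_default_params = {"encoding_format": "float", "dimensions": 512}  # hypothetical caller params
kwargs = {"user": "test-user"}  # hypothetical passthrough kwargs

keys = list(non_default_params.keys())
for k in keys:
    non_default_params.pop(k, None)
final_params = {**non_default_params, **kwargs}

print(final_params)  # {'user': 'test-user'} -> the triton request carries only the raw input payload
```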