Krrish Dholakia 2024-12-11 01:03:57 -08:00
parent 5d1274cb6e
commit 06074bb13b
8 changed files with 197 additions and 62 deletions

View file

@@ -1070,6 +1070,7 @@ from .llms.anthropic.experimental_pass_through.transformation import (
)
from .llms.groq.stt.transformation import GroqSTTConfig
from .llms.anthropic.completion.transformation import AnthropicTextConfig
+from .llms.triton.completion.transformation import TritonConfig
from .llms.databricks.chat.transformation import DatabricksConfig
from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase.chat.transformation import PredibaseConfig

View file

@@ -0,0 +1,15 @@
from typing import Optional, Union

import httpx

from litellm.llms.base_llm.transformation import BaseLLMException


class TritonError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
    ) -> None:
        super().__init__(status_code=status_code, message=message, headers=headers)
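Note: a minimal sketch (not part of this commit) of how the new `TritonError` can wrap an HTTP error from a Triton server; the endpoint URL, status code, and message below are illustrative assumptions. `TritonConfig.get_error_class` in the new transformation module does essentially this:

```python
import httpx

from litellm.llms.triton.common_utils import TritonError

# Hypothetical 4xx response from a Triton /generate endpoint (illustrative only).
resp = httpx.Response(
    status_code=422,
    text="malformed payload",
    request=httpx.Request("POST", "http://localhost:8000/v2/models/llama/generate"),
)

if resp.status_code >= 400:
    # Unlike the removed TritonError in the handler, this one forwards the real
    # response headers to BaseLLMException instead of fabricating a request.
    raise TritonError(
        status_code=resp.status_code, message=resp.text, headers=resp.headers
    )
```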

View file

@@ -5,12 +5,12 @@ from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import httpx  # type: ignore
-import requests  # type: ignore

import litellm
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
+    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.utils import (
@@ -24,22 +24,9 @@ from litellm.utils import (
    map_finish_reason,
)

-from .base import BaseLLM
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
+from ...base import BaseLLM
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import TritonError
-
-
-class TritonError(Exception):
-    def __init__(self, status_code: int, message: str) -> None:
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(
-            method="POST",
-            url="https://api.anthropic.com/v1/messages",  # using anthropic api base since httpx requires a url
-        )
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs


class TritonChatCompletion(BaseLLM):
@@ -142,31 +129,29 @@ class TritonChatCompletion(BaseLLM):
    def completion(
        self,
        model: str,
-        messages: List[dict],
+        messages: List,
        timeout: float,
        api_base: str,
        logging_obj: Any,
        optional_params: dict,
+        litellm_params: dict,
        model_response: ModelResponse,
        api_key: Optional[str] = None,
-        client=None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
        stream: Optional[bool] = False,
        acompletion: bool = False,
+        headers: Optional[dict] = None,
    ) -> ModelResponse:
        type_of_model = ""
        optional_params.pop("stream", False)
        if api_base.endswith("generate"):  ### This is a trtllm model
-            text_input = messages[0]["content"]
-            data_for_triton: Dict[str, Any] = {
-                "text_input": prompt_factory(model=model, messages=messages),
-                "parameters": {
-                    "max_tokens": int(optional_params.get("max_tokens", 2000)),
-                    "bad_words": [""],
-                    "stop_words": [""],
-                },
-                "stream": bool(stream),
-            }
-            data_for_triton["parameters"].update(optional_params)
+            data_for_triton = litellm.TritonConfig().transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                headers=headers or {},
+            )
            type_of_model = "trtllm"

        elif api_base.endswith(
@@ -226,7 +211,13 @@
            },
        )

-        headers = {"Content-Type": "application/json"}
+        headers = litellm.TritonConfig().validate_environment(
+            headers=headers or {},
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
        json_data_for_triton: str = json.dumps(data_for_triton)

        if acompletion:
@@ -240,8 +231,12 @@
                model_response=model_response,
                type_of_model=type_of_model,
            )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            handler = _get_httpx_client()
        else:
-            handler = HTTPHandler()
+            handler = client
+
        if stream:
            return self._handle_stream(  # type: ignore
                handler, api_base, json_data_for_triton, model, logging_obj

View file

@@ -0,0 +1,92 @@
"""
Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint.
"""

from typing import Any, Dict, List, Optional, Union

from httpx import Headers, Response

from litellm.llms.base_llm.transformation import (
    BaseConfig,
    BaseLLMException,
    LiteLLMLoggingObj,
)
from litellm.llms.prompt_templates.factory import prompt_factory
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse

from ..common_utils import TritonError


class TritonConfig(BaseConfig):
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        inference_params = optional_params.copy()
        stream = inference_params.pop("stream", False)
        data_for_triton: Dict[str, Any] = {
            "text_input": prompt_factory(model=model, messages=messages),
            "parameters": {
                "max_tokens": int(optional_params.get("max_tokens", 2000)),
                "bad_words": [""],
                "stop_words": [""],
            },
            "stream": bool(stream),
        }
        data_for_triton["parameters"].update(inference_params)
        return data_for_triton

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
    ) -> BaseLLMException:
        return TritonError(
            status_code=status_code, message=error_message, headers=headers
        )

    def get_supported_openai_params(self, model: str) -> List:
        return ["max_tokens", "max_completion_tokens"]

    def map_openai_params(
        self,
        non_default_params: Dict,
        optional_params: Dict,
        model: str,
        drop_params: bool,
    ) -> Dict:
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params[param] = value
        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError(
            "response transformation done in handler.py. [TODO] Migrate here."
        )

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        api_key: Optional[str] = None,
    ) -> Dict:
        return {"Content-Type": "application/json"}

View file

@@ -126,7 +126,7 @@ from .llms.sagemaker.chat.handler import SagemakerChatHandler
from .llms.sagemaker.completion.handler import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.together_ai.completion.handler import TogetherAITextCompletion
-from .llms.triton import TritonChatCompletion
+from .llms.triton.completion.handler import TritonChatCompletion
from .llms.vertex_ai import vertex_ai_non_gemini
from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
@@ -559,7 +559,9 @@ def mock_completion(
            raise litellm.MockException(
                status_code=getattr(mock_response, "status_code", 500),  # type: ignore
                message=getattr(mock_response, "text", str(mock_response)),
-                llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+                llm_provider=getattr(
+                    mock_response, "llm_provider", custom_llm_provider or "openai"
+                ),  # type: ignore
                model=model,  # type: ignore
                request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
            )
@@ -568,7 +570,9 @@
        ):
            raise litellm.RateLimitError(
                message="this is a mock rate limit error",
-                llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+                llm_provider=getattr(
+                    mock_response, "llm_provider", custom_llm_provider or "openai"
+                ),  # type: ignore
                model=model,
            )
        elif (
@@ -577,7 +581,9 @@
        ):
            raise litellm.InternalServerError(
                message="this is a mock internal server error",
-                llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+                llm_provider=getattr(
+                    mock_response, "llm_provider", custom_llm_provider or "openai"
+                ),  # type: ignore
                model=model,
            )
        elif isinstance(mock_response, str) and mock_response.startswith(
@@ -2374,7 +2380,6 @@ def completion(  # type: ignore # noqa: PLR0915
                return _model_response
            response = _model_response
        elif custom_llm_provider == "text-completion-codestral":
-
            api_base = (
                api_base
                or optional_params.pop("api_base", None)
@@ -2705,6 +2710,8 @@ def completion(  # type: ignore # noqa: PLR0915
                logging_obj=logging,
                stream=stream,
                acompletion=acompletion,
+                client=client,
+                litellm_params=litellm_params,
            )

            ## RESPONSE OBJECT
@@ -2944,7 +2951,9 @@ def completion_with_retries(*args, **kwargs):
        )

    num_retries = kwargs.pop("num_retries", 3)
-    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry")  # type: ignore
+    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop(
+        "retry_strategy", "constant_retry"
+    )  # type: ignore
    original_function = kwargs.pop("original_function", completion)
    if retry_strategy == "exponential_backoff_retry":
        retryer = tenacity.Retrying(
@@ -3331,9 +3340,7 @@ def embedding(  # noqa: PLR0915
            max_retries=max_retries,
        )
    elif custom_llm_provider == "databricks":
-        api_base = (
-            api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")
-        )  # type: ignore
+        api_base = api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")  # type: ignore

        # set API KEY
        api_key = (
@@ -3465,7 +3472,6 @@ def embedding(  # noqa: PLR0915
            aembedding=aembedding,
        )
    elif custom_llm_provider == "gemini":
-
        gemini_api_key = (
            api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
        )
@@ -3960,7 +3966,11 @@ def text_completion(  # noqa: PLR0915
        optional_params["custom_llm_provider"] = custom_llm_provider

    # get custom_llm_provider
-    _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model,  # type: ignore
+        custom_llm_provider=custom_llm_provider,
+        api_base=api_base,
+    )

    if custom_llm_provider == "huggingface":
        # if echo == True, for TGI llms we need to set top_n_tokens to 3
@@ -4212,7 +4222,6 @@ async def amoderation(
    )
    openai_client = kwargs.get("client", None)
    if openai_client is None or not isinstance(openai_client, AsyncOpenAI):
-
        # call helper to get OpenAI client
        # _get_openai_client maintains in-memory caching logic for OpenAI clients
        _openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client(  # type: ignore
@@ -4322,7 +4331,11 @@ def image_generation(  # noqa: PLR0915
        headers.update(extra_headers)
    model_response: ImageResponse = litellm.utils.ImageResponse()
    if model is not None or custom_llm_provider is not None:
-        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+            model=model,  # type: ignore
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+        )
    else:
        model = "dall-e-2"
        custom_llm_provider = "openai"  # default to dall-e-2 on openai
@@ -4644,7 +4657,9 @@ def transcription(
    model_response = litellm.utils.TranscriptionResponse()

-    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
+    )  # type: ignore

    if dynamic_api_key is not None:
        api_key = dynamic_api_key
@@ -4710,12 +4725,7 @@ def transcription(
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )
        # set API KEY
-        api_key = (
-            api_key
-            or litellm.api_key
-            or litellm.openai_key
-            or get_secret("OPENAI_API_KEY")
-        )  # type: ignore
+        api_key = api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")  # type: ignore
        response = openai_audio_transcriptions.audio_transcriptions(
            model=model,
            audio_file=file,
@@ -4802,7 +4812,9 @@ def speech(
    proxy_server_request = kwargs.get("proxy_server_request", None)
    extra_headers = kwargs.get("extra_headers", None)
    model_info = kwargs.get("model_info", None)
-    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
+    )  # type: ignore
    kwargs.pop("tags", [])

    optional_params = {}
@@ -4895,9 +4907,7 @@ def speech(
            )
        api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")  # type: ignore

-        api_version = (
-            api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
-        )  # type: ignore
+        api_version = api_version or litellm.api_version or get_secret("AZURE_API_VERSION")  # type: ignore

        api_key = (
            api_key
@@ -5004,7 +5014,6 @@ async def ahealth_check(  # noqa: PLR0915
    """
    passed_in_mode: Optional[str] = None
    try:
-
        model: Optional[str] = model_params.get("model", None)

        if model is None:

View file

@@ -6373,6 +6373,8 @@ class ProviderConfigManager:
            return litellm.OllamaConfig()
        elif litellm.LlmProviders.PREDIBASE == provider:
            return litellm.PredibaseConfig()
+        elif litellm.LlmProviders.TRITON == provider:
+            return litellm.TritonConfig()
        return litellm.OpenAIGPTConfig()
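Note: a small sketch (not in this commit) of what the new branch enables, assuming `ProviderConfigManager.get_provider_chat_config(model, provider)` lives in `litellm.utils` as hinted by the commented-out test below:

```python
import litellm
from litellm.utils import ProviderConfigManager

# With the TRITON branch above, the manager now resolves to the new TritonConfig.
config = ProviderConfigManager.get_provider_chat_config(
    model="llama-3-8b-instruct", provider=litellm.LlmProviders.TRITON
)
assert isinstance(config, litellm.TritonConfig)

# Per TritonConfig.get_supported_openai_params above:
print(config.get_supported_openai_params(model="llama-3-8b-instruct"))
# -> ["max_tokens", "max_completion_tokens"]
```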

View file

@ -322,11 +322,7 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders):
# or provider == LlmProviders.VERTEX_AI_BETA # or provider == LlmProviders.VERTEX_AI_BETA
# or provider == LlmProviders.BEDROCK # or provider == LlmProviders.BEDROCK
# or provider == LlmProviders.BASETEN # or provider == LlmProviders.BASETEN
# or provider == LlmProviders.SAGEMAKER
# or provider == LlmProviders.SAGEMAKER_CHAT
# or provider == LlmProviders.VLLM
# or provider == LlmProviders.PETALS # or provider == LlmProviders.PETALS
# or provider == LlmProviders.OLLAMA
# ): # ):
# continue # continue
# config = ProviderConfigManager.get_provider_chat_config( # config = ProviderConfigManager.get_provider_chat_config(

View file

@@ -1,5 +1,5 @@
import pytest
-from litellm.llms.triton import TritonChatCompletion
+from litellm.llms.triton.completion.handler import TritonChatCompletion


def test_split_embedding_by_shape_passes():
@@ -29,3 +29,28 @@ def test_split_embedding_by_shape_fails_with_shape_value_error():
    ]
    with pytest.raises(ValueError):
        triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
+
+
+def test_completion_triton():
+    from litellm import completion
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    from unittest.mock import patch, MagicMock, AsyncMock
+
+    client = HTTPHandler()
+    with patch.object(client, "post") as mock_post:
+        try:
+            response = completion(
+                model="triton/llama-3-8b-instruct",
+                messages=[{"role": "user", "content": "who are u?"}],
+                max_tokens=10,
+                timeout=5,
+                client=client,
+                api_base="http://localhost:8000/generate",
+            )
+            print(response)
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
+
+        print(mock_post.call_args.kwargs)