Krrish Dholakia 2024-12-11 01:03:57 -08:00
parent 5d1274cb6e
commit 06074bb13b
8 changed files with 197 additions and 62 deletions

View file

@ -1070,6 +1070,7 @@ from .llms.anthropic.experimental_pass_through.transformation import (
)
from .llms.groq.stt.transformation import GroqSTTConfig
from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.triton.completion.transformation import TritonConfig
from .llms.databricks.chat.transformation import DatabricksConfig
from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase.chat.transformation import PredibaseConfig

View file

@ -0,0 +1,15 @@
from typing import Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
class TritonError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[dict, httpx.Headers]] = None,
) -> None:
super().__init__(status_code=status_code, message=message, headers=headers)

View file

@ -5,12 +5,12 @@ from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.utils import (
@ -24,22 +24,9 @@ from litellm.utils import (
map_finish_reason,
)
from .base import BaseLLM
from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
class TritonError(Exception):
def __init__(self, status_code: int, message: str) -> None:
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST",
url="https://api.anthropic.com/v1/messages", # using anthropic api base since httpx requires a url
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
from ...base import BaseLLM
from ...prompt_templates.factory import custom_prompt, prompt_factory
from ..common_utils import TritonError
class TritonChatCompletion(BaseLLM):
@ -142,31 +129,29 @@ class TritonChatCompletion(BaseLLM):
def completion(
self,
model: str,
messages: List[dict],
messages: List,
timeout: float,
api_base: str,
logging_obj: Any,
optional_params: dict,
litellm_params: dict,
model_response: ModelResponse,
api_key: Optional[str] = None,
client=None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
stream: Optional[bool] = False,
acompletion: bool = False,
headers: Optional[dict] = None,
) -> ModelResponse:
type_of_model = ""
optional_params.pop("stream", False)
if api_base.endswith("generate"): ### This is a trtllm model
text_input = messages[0]["content"]
data_for_triton: Dict[str, Any] = {
"text_input": prompt_factory(model=model, messages=messages),
"parameters": {
"max_tokens": int(optional_params.get("max_tokens", 2000)),
"bad_words": [""],
"stop_words": [""],
},
"stream": bool(stream),
}
data_for_triton["parameters"].update(optional_params)
data_for_triton = litellm.TritonConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers or {},
)
type_of_model = "trtllm"
elif api_base.endswith(
@ -226,7 +211,13 @@ class TritonChatCompletion(BaseLLM):
},
)
headers = {"Content-Type": "application/json"}
headers = litellm.TritonConfig().validate_environment(
headers=headers or {},
model=model,
messages=messages,
optional_params=optional_params,
api_key=api_key,
)
json_data_for_triton: str = json.dumps(data_for_triton)
if acompletion:
@ -240,8 +231,12 @@ class TritonChatCompletion(BaseLLM):
model_response=model_response,
type_of_model=type_of_model,
)
if client is None or not isinstance(client, HTTPHandler):
handler = _get_httpx_client()
else:
handler = HTTPHandler()
handler = client
if stream:
return self._handle_stream( # type: ignore
handler, api_base, json_data_for_triton, model, logging_obj

View file

@ -0,0 +1,92 @@
"""
Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint.
"""
from typing import Any, Dict, List, Optional, Union
from httpx import Headers, Response
from litellm.llms.base_llm.transformation import (
BaseConfig,
BaseLLMException,
LiteLLMLoggingObj,
)
from litellm.llms.prompt_templates.factory import prompt_factory
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from ..common_utils import TritonError
class TritonConfig(BaseConfig):
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
inference_params = optional_params.copy()
stream = inference_params.pop("stream", False)
data_for_triton: Dict[str, Any] = {
"text_input": prompt_factory(model=model, messages=messages),
"parameters": {
"max_tokens": int(optional_params.get("max_tokens", 2000)),
"bad_words": [""],
"stop_words": [""],
},
"stream": bool(stream),
}
data_for_triton["parameters"].update(inference_params)
return data_for_triton
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, Headers]
) -> BaseLLMException:
return TritonError(
status_code=status_code, message=error_message, headers=headers
)
def get_supported_openai_params(self, model: str) -> List:
return ["max_tokens", "max_completion_tokens"]
def map_openai_params(
self,
non_default_params: Dict,
optional_params: Dict,
model: str,
drop_params: bool,
) -> Dict:
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params[param] = value
return optional_params
def transform_response(
self,
model: str,
raw_response: Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
raise NotImplementedError(
"response transformation done in handler.py. [TODO] Migrate here."
)
def validate_environment(
self,
headers: Dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
) -> Dict:
return {"Content-Type": "application/json"}

View file

@ -126,7 +126,7 @@ from .llms.sagemaker.chat.handler import SagemakerChatHandler
from .llms.sagemaker.completion.handler import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton import TritonChatCompletion
from .llms.triton.completion.handler import TritonChatCompletion
from .llms.vertex_ai import vertex_ai_non_gemini
from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
@ -559,7 +559,9 @@ def mock_completion(
raise litellm.MockException(
status_code=getattr(mock_response, "status_code", 500), # type: ignore
message=getattr(mock_response, "text", str(mock_response)),
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
@ -568,7 +570,9 @@ def mock_completion(
):
raise litellm.RateLimitError(
message="this is a mock rate limit error",
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model,
)
elif (
@ -577,7 +581,9 @@ def mock_completion(
):
raise litellm.InternalServerError(
message="this is a mock internal server error",
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model,
)
elif isinstance(mock_response, str) and mock_response.startswith(
@ -2374,7 +2380,6 @@ def completion( # type: ignore # noqa: PLR0915
return _model_response
response = _model_response
elif custom_llm_provider == "text-completion-codestral":
api_base = (
api_base
or optional_params.pop("api_base", None)
@ -2705,6 +2710,8 @@ def completion( # type: ignore # noqa: PLR0915
logging_obj=logging,
stream=stream,
acompletion=acompletion,
client=client,
litellm_params=litellm_params,
)
## RESPONSE OBJECT
@ -2944,7 +2951,9 @@ def completion_with_retries(*args, **kwargs):
)
num_retries = kwargs.pop("num_retries", 3)
retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry") # type: ignore
retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop(
"retry_strategy", "constant_retry"
) # type: ignore
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(
@ -3331,9 +3340,7 @@ def embedding( # noqa: PLR0915
max_retries=max_retries,
)
elif custom_llm_provider == "databricks":
api_base = (
api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")
) # type: ignore
api_base = api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE") # type: ignore
# set API KEY
api_key = (
@ -3465,7 +3472,6 @@ def embedding( # noqa: PLR0915
aembedding=aembedding,
)
elif custom_llm_provider == "gemini":
gemini_api_key = (
api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
)
@ -3960,7 +3966,11 @@ def text_completion( # noqa: PLR0915
optional_params["custom_llm_provider"] = custom_llm_provider
# get custom_llm_provider
_model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
_model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
model=model, # type: ignore
custom_llm_provider=custom_llm_provider,
api_base=api_base,
)
if custom_llm_provider == "huggingface":
# if echo == True, for TGI llms we need to set top_n_tokens to 3
@ -4212,7 +4222,6 @@ async def amoderation(
)
openai_client = kwargs.get("client", None)
if openai_client is None or not isinstance(openai_client, AsyncOpenAI):
# call helper to get OpenAI client
# _get_openai_client maintains in-memory caching logic for OpenAI clients
_openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client( # type: ignore
@ -4322,7 +4331,11 @@ def image_generation( # noqa: PLR0915
headers.update(extra_headers)
model_response: ImageResponse = litellm.utils.ImageResponse()
if model is not None or custom_llm_provider is not None:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
model=model, # type: ignore
custom_llm_provider=custom_llm_provider,
api_base=api_base,
)
else:
model = "dall-e-2"
custom_llm_provider = "openai" # default to dall-e-2 on openai
@ -4644,7 +4657,9 @@ def transcription(
model_response = litellm.utils.TranscriptionResponse()
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
) # type: ignore
if dynamic_api_key is not None:
api_key = dynamic_api_key
@ -4710,12 +4725,7 @@ def transcription(
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
api_key
or litellm.api_key
or litellm.openai_key
or get_secret("OPENAI_API_KEY")
) # type: ignore
api_key = api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") # type: ignore
response = openai_audio_transcriptions.audio_transcriptions(
model=model,
audio_file=file,
@ -4802,7 +4812,9 @@ def speech(
proxy_server_request = kwargs.get("proxy_server_request", None)
extra_headers = kwargs.get("extra_headers", None)
model_info = kwargs.get("model_info", None)
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
) # type: ignore
kwargs.pop("tags", [])
optional_params = {}
@ -4895,9 +4907,7 @@ def speech(
)
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
) # type: ignore
api_version = api_version or litellm.api_version or get_secret("AZURE_API_VERSION") # type: ignore
api_key = (
api_key
@ -5004,7 +5014,6 @@ async def ahealth_check( # noqa: PLR0915
"""
passed_in_mode: Optional[str] = None
try:
model: Optional[str] = model_params.get("model", None)
if model is None:

View file

@ -6373,6 +6373,8 @@ class ProviderConfigManager:
return litellm.OllamaConfig()
elif litellm.LlmProviders.PREDIBASE == provider:
return litellm.PredibaseConfig()
elif litellm.LlmProviders.TRITON == provider:
return litellm.TritonConfig()
return litellm.OpenAIGPTConfig()

View file

@ -322,11 +322,7 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders):
# or provider == LlmProviders.VERTEX_AI_BETA
# or provider == LlmProviders.BEDROCK
# or provider == LlmProviders.BASETEN
# or provider == LlmProviders.SAGEMAKER
# or provider == LlmProviders.SAGEMAKER_CHAT
# or provider == LlmProviders.VLLM
# or provider == LlmProviders.PETALS
# or provider == LlmProviders.OLLAMA
# ):
# continue
# config = ProviderConfigManager.get_provider_chat_config(

View file

@ -1,5 +1,5 @@
import pytest
from litellm.llms.triton import TritonChatCompletion
from litellm.llms.triton.completion.handler import TritonChatCompletion
def test_split_embedding_by_shape_passes():
@ -29,3 +29,28 @@ def test_split_embedding_by_shape_fails_with_shape_value_error():
]
with pytest.raises(ValueError):
triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
def test_completion_triton():
from litellm import completion
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock
client = HTTPHandler()
with patch.object(client, "post") as mock_post:
try:
response = completion(
model="triton/llama-3-8b-instruct",
messages=[{"role": "user", "content": "who are u?"}],
max_tokens=10,
timeout=5,
client=client,
api_base="http://localhost:8000/generate",
)
print(response)
except Exception as e:
print(e)
mock_post.assert_called_once()
print(mock_post.call_args.kwargs)