From 06074bb13b2ac103fa4ba43701c15c44ab69b718 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 11 Dec 2024 01:03:57 -0800 Subject: [PATCH] build: Squashed commit of https://github.com/BerriAI/litellm/pull/7170 Closes https://github.com/BerriAI/litellm/pull/7170 --- litellm/__init__.py | 1 + litellm/llms/triton/common_utils.py | 15 +++ .../completion/handler.py} | 59 ++++++------ .../llms/triton/completion/transformation.py | 92 +++++++++++++++++++ litellm/main.py | 59 +++++++----- litellm/utils.py | 2 + tests/local_testing/test_config.py | 4 - tests/local_testing/test_triton.py | 27 +++++- 8 files changed, 197 insertions(+), 62 deletions(-) create mode 100644 litellm/llms/triton/common_utils.py rename litellm/llms/{triton.py => triton/completion/handler.py} (88%) create mode 100644 litellm/llms/triton/completion/transformation.py diff --git a/litellm/__init__.py b/litellm/__init__.py index f70f9d266e..adb49c739e 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1070,6 +1070,7 @@ from .llms.anthropic.experimental_pass_through.transformation import ( ) from .llms.groq.stt.transformation import GroqSTTConfig from .llms.anthropic.completion.transformation import AnthropicTextConfig +from .llms.triton.completion.transformation import TritonConfig from .llms.databricks.chat.transformation import DatabricksConfig from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig from .llms.predibase.chat.transformation import PredibaseConfig diff --git a/litellm/llms/triton/common_utils.py b/litellm/llms/triton/common_utils.py new file mode 100644 index 0000000000..64ce011b95 --- /dev/null +++ b/litellm/llms/triton/common_utils.py @@ -0,0 +1,15 @@ +from typing import Optional, Union + +import httpx + +from litellm.llms.base_llm.transformation import BaseLLMException + + +class TritonError(BaseLLMException): + def __init__( + self, + status_code: int, + message: str, + headers: Optional[Union[dict, httpx.Headers]] = None, + ) -> None: + super().__init__(status_code=status_code, message=message, headers=headers) diff --git a/litellm/llms/triton.py b/litellm/llms/triton/completion/handler.py similarity index 88% rename from litellm/llms/triton.py rename to litellm/llms/triton/completion/handler.py index dd715fda44..5b3e7df0e3 100644 --- a/litellm/llms/triton.py +++ b/litellm/llms/triton/completion/handler.py @@ -5,12 +5,12 @@ from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence, Union import httpx # type: ignore -import requests # type: ignore import litellm from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, + _get_httpx_client, get_async_httpx_client, ) from litellm.utils import ( @@ -24,22 +24,9 @@ from litellm.utils import ( map_finish_reason, ) -from .base import BaseLLM -from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory - - -class TritonError(Exception): - def __init__(self, status_code: int, message: str) -> None: - self.status_code = status_code - self.message = message - self.request = httpx.Request( - method="POST", - url="https://api.anthropic.com/v1/messages", # using anthropic api base since httpx requires a url - ) - self.response = httpx.Response(status_code=status_code, request=self.request) - super().__init__( - self.message - ) # Call the base class constructor with the parameters it needs +from ...base import BaseLLM +from ...prompt_templates.factory import custom_prompt, prompt_factory +from ..common_utils import TritonError class 
TritonChatCompletion(BaseLLM): @@ -142,31 +129,29 @@ class TritonChatCompletion(BaseLLM): def completion( self, model: str, - messages: List[dict], + messages: List, timeout: float, api_base: str, logging_obj: Any, optional_params: dict, + litellm_params: dict, model_response: ModelResponse, api_key: Optional[str] = None, - client=None, + client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, stream: Optional[bool] = False, acompletion: bool = False, + headers: Optional[dict] = None, ) -> ModelResponse: type_of_model = "" optional_params.pop("stream", False) if api_base.endswith("generate"): ### This is a trtllm model - text_input = messages[0]["content"] - data_for_triton: Dict[str, Any] = { - "text_input": prompt_factory(model=model, messages=messages), - "parameters": { - "max_tokens": int(optional_params.get("max_tokens", 2000)), - "bad_words": [""], - "stop_words": [""], - }, - "stream": bool(stream), - } - data_for_triton["parameters"].update(optional_params) + data_for_triton = litellm.TritonConfig().transform_request( + model=model, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + headers=headers or {}, + ) type_of_model = "trtllm" elif api_base.endswith( @@ -226,7 +211,13 @@ class TritonChatCompletion(BaseLLM): }, ) - headers = {"Content-Type": "application/json"} + headers = litellm.TritonConfig().validate_environment( + headers=headers or {}, + model=model, + messages=messages, + optional_params=optional_params, + api_key=api_key, + ) json_data_for_triton: str = json.dumps(data_for_triton) if acompletion: @@ -240,8 +231,12 @@ class TritonChatCompletion(BaseLLM): model_response=model_response, type_of_model=type_of_model, ) + + if client is None or not isinstance(client, HTTPHandler): + handler = _get_httpx_client() else: - handler = HTTPHandler() + handler = client + if stream: return self._handle_stream( # type: ignore handler, api_base, json_data_for_triton, model, logging_obj diff --git a/litellm/llms/triton/completion/transformation.py b/litellm/llms/triton/completion/transformation.py new file mode 100644 index 0000000000..ccf920ddb9 --- /dev/null +++ b/litellm/llms/triton/completion/transformation.py @@ -0,0 +1,92 @@ +""" +Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint. 
+""" + +from typing import Any, Dict, List, Optional, Union + +from httpx import Headers, Response + +from litellm.llms.base_llm.transformation import ( + BaseConfig, + BaseLLMException, + LiteLLMLoggingObj, +) +from litellm.llms.prompt_templates.factory import prompt_factory +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ModelResponse + +from ..common_utils import TritonError + + +class TritonConfig(BaseConfig): + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + inference_params = optional_params.copy() + stream = inference_params.pop("stream", False) + data_for_triton: Dict[str, Any] = { + "text_input": prompt_factory(model=model, messages=messages), + "parameters": { + "max_tokens": int(optional_params.get("max_tokens", 2000)), + "bad_words": [""], + "stop_words": [""], + }, + "stream": bool(stream), + } + data_for_triton["parameters"].update(inference_params) + return data_for_triton + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[Dict, Headers] + ) -> BaseLLMException: + return TritonError( + status_code=status_code, message=error_message, headers=headers + ) + + def get_supported_openai_params(self, model: str) -> List: + return ["max_tokens", "max_completion_tokens"] + + def map_openai_params( + self, + non_default_params: Dict, + optional_params: Dict, + model: str, + drop_params: bool, + ) -> Dict: + for param, value in non_default_params.items(): + if param == "max_tokens" or param == "max_completion_tokens": + optional_params[param] = value + return optional_params + + def transform_response( + self, + model: str, + raw_response: Response, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: Dict, + messages: List[AllMessageValues], + optional_params: Dict, + litellm_params: Dict, + encoding: Any, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + raise NotImplementedError( + "response transformation done in handler.py. [TODO] Migrate here." 
+ ) + + def validate_environment( + self, + headers: Dict, + model: str, + messages: List[AllMessageValues], + optional_params: Dict, + api_key: Optional[str] = None, + ) -> Dict: + return {"Content-Type": "application/json"} diff --git a/litellm/main.py b/litellm/main.py index 4350689d00..f9aec9e0ab 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -126,7 +126,7 @@ from .llms.sagemaker.chat.handler import SagemakerChatHandler from .llms.sagemaker.completion.handler import SagemakerLLM from .llms.text_completion_codestral import CodestralTextCompletion from .llms.together_ai.completion.handler import TogetherAITextCompletion -from .llms.triton import TritonChatCompletion +from .llms.triton.completion.handler import TritonChatCompletion from .llms.vertex_ai import vertex_ai_non_gemini from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import ( @@ -559,7 +559,9 @@ def mock_completion( raise litellm.MockException( status_code=getattr(mock_response, "status_code", 500), # type: ignore message=getattr(mock_response, "text", str(mock_response)), - llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore + llm_provider=getattr( + mock_response, "llm_provider", custom_llm_provider or "openai" + ), # type: ignore model=model, # type: ignore request=httpx.Request(method="POST", url="https://api.openai.com/v1/"), ) @@ -568,7 +570,9 @@ def mock_completion( ): raise litellm.RateLimitError( message="this is a mock rate limit error", - llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore + llm_provider=getattr( + mock_response, "llm_provider", custom_llm_provider or "openai" + ), # type: ignore model=model, ) elif ( @@ -577,7 +581,9 @@ def mock_completion( ): raise litellm.InternalServerError( message="this is a mock internal server error", - llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore + llm_provider=getattr( + mock_response, "llm_provider", custom_llm_provider or "openai" + ), # type: ignore model=model, ) elif isinstance(mock_response, str) and mock_response.startswith( @@ -2374,7 +2380,6 @@ def completion( # type: ignore # noqa: PLR0915 return _model_response response = _model_response elif custom_llm_provider == "text-completion-codestral": - api_base = ( api_base or optional_params.pop("api_base", None) @@ -2705,6 +2710,8 @@ def completion( # type: ignore # noqa: PLR0915 logging_obj=logging, stream=stream, acompletion=acompletion, + client=client, + litellm_params=litellm_params, ) ## RESPONSE OBJECT @@ -2944,7 +2951,9 @@ def completion_with_retries(*args, **kwargs): ) num_retries = kwargs.pop("num_retries", 3) - retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry") # type: ignore + retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop( + "retry_strategy", "constant_retry" + ) # type: ignore original_function = kwargs.pop("original_function", completion) if retry_strategy == "exponential_backoff_retry": retryer = tenacity.Retrying( @@ -3331,9 +3340,7 @@ def embedding( # noqa: PLR0915 max_retries=max_retries, ) elif custom_llm_provider == "databricks": - api_base = ( - api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE") - ) # type: ignore + api_base = api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE") # type: ignore # set API KEY 
api_key = ( @@ -3465,7 +3472,6 @@ def embedding( # noqa: PLR0915 aembedding=aembedding, ) elif custom_llm_provider == "gemini": - gemini_api_key = ( api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key ) @@ -3960,7 +3966,11 @@ def text_completion( # noqa: PLR0915 optional_params["custom_llm_provider"] = custom_llm_provider # get custom_llm_provider - _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model, # type: ignore + custom_llm_provider=custom_llm_provider, + api_base=api_base, + ) if custom_llm_provider == "huggingface": # if echo == True, for TGI llms we need to set top_n_tokens to 3 @@ -4212,7 +4222,6 @@ async def amoderation( ) openai_client = kwargs.get("client", None) if openai_client is None or not isinstance(openai_client, AsyncOpenAI): - # call helper to get OpenAI client # _get_openai_client maintains in-memory caching logic for OpenAI clients _openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client( # type: ignore @@ -4322,7 +4331,11 @@ def image_generation( # noqa: PLR0915 headers.update(extra_headers) model_response: ImageResponse = litellm.utils.ImageResponse() if model is not None or custom_llm_provider is not None: - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model, # type: ignore + custom_llm_provider=custom_llm_provider, + api_base=api_base, + ) else: model = "dall-e-2" custom_llm_provider = "openai" # default to dall-e-2 on openai @@ -4644,7 +4657,9 @@ def transcription( model_response = litellm.utils.TranscriptionResponse() - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model, custom_llm_provider=custom_llm_provider, api_base=api_base + ) # type: ignore if dynamic_api_key is not None: api_key = dynamic_api_key @@ -4710,12 +4725,7 @@ def transcription( or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 ) # set API KEY - api_key = ( - api_key - or litellm.api_key - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) # type: ignore + api_key = api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") # type: ignore response = openai_audio_transcriptions.audio_transcriptions( model=model, audio_file=file, @@ -4802,7 +4812,9 @@ def speech( proxy_server_request = kwargs.get("proxy_server_request", None) extra_headers = kwargs.get("extra_headers", None) model_info = kwargs.get("model_info", None) - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model, custom_llm_provider=custom_llm_provider, api_base=api_base + ) # type: ignore kwargs.pop("tags", []) optional_params = {} @@ -4895,9 +4907,7 @@ def speech( ) api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore - api_version = ( - api_version or litellm.api_version or 
get_secret("AZURE_API_VERSION") - ) # type: ignore + api_version = api_version or litellm.api_version or get_secret("AZURE_API_VERSION") # type: ignore api_key = ( api_key @@ -5004,7 +5014,6 @@ async def ahealth_check( # noqa: PLR0915 """ passed_in_mode: Optional[str] = None try: - model: Optional[str] = model_params.get("model", None) if model is None: diff --git a/litellm/utils.py b/litellm/utils.py index a55708f94b..1772e1337b 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6373,6 +6373,8 @@ class ProviderConfigManager: return litellm.OllamaConfig() elif litellm.LlmProviders.PREDIBASE == provider: return litellm.PredibaseConfig() + elif litellm.LlmProviders.TRITON == provider: + return litellm.TritonConfig() return litellm.OpenAIGPTConfig() diff --git a/tests/local_testing/test_config.py b/tests/local_testing/test_config.py index eec8be5115..2d49338539 100644 --- a/tests/local_testing/test_config.py +++ b/tests/local_testing/test_config.py @@ -322,11 +322,7 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders): # or provider == LlmProviders.VERTEX_AI_BETA # or provider == LlmProviders.BEDROCK # or provider == LlmProviders.BASETEN -# or provider == LlmProviders.SAGEMAKER -# or provider == LlmProviders.SAGEMAKER_CHAT -# or provider == LlmProviders.VLLM # or provider == LlmProviders.PETALS -# or provider == LlmProviders.OLLAMA # ): # continue # config = ProviderConfigManager.get_provider_chat_config( diff --git a/tests/local_testing/test_triton.py b/tests/local_testing/test_triton.py index 122247c8a0..2c39b16001 100644 --- a/tests/local_testing/test_triton.py +++ b/tests/local_testing/test_triton.py @@ -1,5 +1,5 @@ import pytest -from litellm.llms.triton import TritonChatCompletion +from litellm.llms.triton.completion.handler import TritonChatCompletion def test_split_embedding_by_shape_passes(): @@ -29,3 +29,28 @@ def test_split_embedding_by_shape_fails_with_shape_value_error(): ] with pytest.raises(ValueError): triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"]) + + +def test_completion_triton(): + from litellm import completion + from litellm.llms.custom_httpx.http_handler import HTTPHandler + from unittest.mock import patch, MagicMock, AsyncMock + + client = HTTPHandler() + with patch.object(client, "post") as mock_post: + try: + response = completion( + model="triton/llama-3-8b-instruct", + messages=[{"role": "user", "content": "who are u?"}], + max_tokens=10, + timeout=5, + client=client, + api_base="http://localhost:8000/generate", + ) + print(response) + except Exception as e: + print(e) + + mock_post.assert_called_once() + + print(mock_post.call_args.kwargs)