Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-25 18:54:30 +00:00

build: Squashed commit of https://github.com/BerriAI/litellm/pull/7170

Closes https://github.com/BerriAI/litellm/pull/7170

This commit is contained in:
  parent 5d1274cb6e
  commit 06074bb13b

8 changed files with 197 additions and 62 deletions
@@ -1070,6 +1070,7 @@ from .llms.anthropic.experimental_pass_through.transformation import (
 )
 from .llms.groq.stt.transformation import GroqSTTConfig
 from .llms.anthropic.completion.transformation import AnthropicTextConfig
+from .llms.triton.completion.transformation import TritonConfig
 from .llms.databricks.chat.transformation import DatabricksConfig
 from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
 from .llms.predibase.chat.transformation import PredibaseConfig
litellm/llms/triton/common_utils.py  (new file, 15 lines)

@@ -0,0 +1,15 @@
+from typing import Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class TritonError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[Union[dict, httpx.Headers]] = None,
+    ) -> None:
+        super().__init__(status_code=status_code, message=message, headers=headers)
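For orientation, here is a minimal sketch (not part of this commit) of how the shared TritonError defined above might be raised when a Triton HTTP call fails; raise_for_triton_error is a hypothetical helper name used only for illustration.

import httpx

from litellm.llms.triton.common_utils import TritonError


def raise_for_triton_error(response: httpx.Response) -> None:
    # Hypothetical helper: surface a non-2xx Triton response as TritonError,
    # carrying status code, body text, and response headers.
    if response.status_code >= 400:
        raise TritonError(
            status_code=response.status_code,
            message=response.text,
            headers=response.headers,
        )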
@@ -5,12 +5,12 @@ from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
 import httpx  # type: ignore
-import requests  # type: ignore
 
 import litellm
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
+    _get_httpx_client,
     get_async_httpx_client,
 )
 from litellm.utils import (
@@ -24,22 +24,9 @@ from litellm.utils import (
     map_finish_reason,
 )
 
-from .base import BaseLLM
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class TritonError(Exception):
-    def __init__(self, status_code: int, message: str) -> None:
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(
-            method="POST",
-            url="https://api.anthropic.com/v1/messages",  # using anthropic api base since httpx requires a url
-        )
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
+from ...base import BaseLLM
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import TritonError
 
 
 class TritonChatCompletion(BaseLLM):
@@ -142,31 +129,29 @@ class TritonChatCompletion(BaseLLM):
     def completion(
         self,
         model: str,
-        messages: List[dict],
+        messages: List,
         timeout: float,
         api_base: str,
         logging_obj: Any,
         optional_params: dict,
+        litellm_params: dict,
         model_response: ModelResponse,
         api_key: Optional[str] = None,
-        client=None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
         stream: Optional[bool] = False,
         acompletion: bool = False,
+        headers: Optional[dict] = None,
     ) -> ModelResponse:
         type_of_model = ""
         optional_params.pop("stream", False)
         if api_base.endswith("generate"):  ### This is a trtllm model
-            text_input = messages[0]["content"]
-            data_for_triton: Dict[str, Any] = {
-                "text_input": prompt_factory(model=model, messages=messages),
-                "parameters": {
-                    "max_tokens": int(optional_params.get("max_tokens", 2000)),
-                    "bad_words": [""],
-                    "stop_words": [""],
-                },
-                "stream": bool(stream),
-            }
-            data_for_triton["parameters"].update(optional_params)
+            data_for_triton = litellm.TritonConfig().transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                headers=headers or {},
+            )
             type_of_model = "trtllm"
 
         elif api_base.endswith(
@@ -226,7 +211,13 @@ class TritonChatCompletion(BaseLLM):
             },
         )
 
-        headers = {"Content-Type": "application/json"}
+        headers = litellm.TritonConfig().validate_environment(
+            headers=headers or {},
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
         json_data_for_triton: str = json.dumps(data_for_triton)
 
         if acompletion:
@@ -240,8 +231,12 @@
                 model_response=model_response,
                 type_of_model=type_of_model,
             )
 
-        handler = HTTPHandler()
+        if client is None or not isinstance(client, HTTPHandler):
+            handler = _get_httpx_client()
+        else:
+            handler = client
+
         if stream:
             return self._handle_stream(  # type: ignore
                 handler, api_base, json_data_for_triton, model, logging_obj
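The hunk above replaces the unconditional HTTPHandler() with reuse of a caller-supplied client. A minimal standalone sketch of that selection logic, assuming the same HTTPHandler/_get_httpx_client imports this diff adds; pick_handler is a hypothetical name used only for illustration.

from typing import Optional, Union

from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
)


def pick_handler(
    client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> HTTPHandler:
    # Reuse a caller-supplied synchronous handler; otherwise build one via
    # litellm's _get_httpx_client helper.
    if client is None or not isinstance(client, HTTPHandler):
        return _get_httpx_client()
    return client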
litellm/llms/triton/completion/transformation.py  (new file, 92 lines)

@@ -0,0 +1,92 @@
+"""
+Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint.
+"""
+
+from typing import Any, Dict, List, Optional, Union
+
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.transformation import (
+    BaseConfig,
+    BaseLLMException,
+    LiteLLMLoggingObj,
+)
+from litellm.llms.prompt_templates.factory import prompt_factory
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ..common_utils import TritonError
+
+
+class TritonConfig(BaseConfig):
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        inference_params = optional_params.copy()
+        stream = inference_params.pop("stream", False)
+        data_for_triton: Dict[str, Any] = {
+            "text_input": prompt_factory(model=model, messages=messages),
+            "parameters": {
+                "max_tokens": int(optional_params.get("max_tokens", 2000)),
+                "bad_words": [""],
+                "stop_words": [""],
+            },
+            "stream": bool(stream),
+        }
+        data_for_triton["parameters"].update(inference_params)
+        return data_for_triton
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
+    ) -> BaseLLMException:
+        return TritonError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return ["max_tokens", "max_completion_tokens"]
+
+    def map_openai_params(
+        self,
+        non_default_params: Dict,
+        optional_params: Dict,
+        model: str,
+        drop_params: bool,
+    ) -> Dict:
+        for param, value in non_default_params.items():
+            if param == "max_tokens" or param == "max_completion_tokens":
+                optional_params[param] = value
+        return optional_params
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: Dict,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "response transformation done in handler.py. [TODO] Migrate here."
+        )
+
+    def validate_environment(
+        self,
+        headers: Dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        api_key: Optional[str] = None,
+    ) -> Dict:
+        return {"Content-Type": "application/json"}
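A hedged usage sketch of the new transformation class, assuming TritonConfig is re-exported at the package top level (as the first hunk of this commit suggests); the commented payload shape follows directly from transform_request above.

import litellm

payload = litellm.TritonConfig().transform_request(
    model="triton/llama-3-8b-instruct",
    messages=[{"role": "user", "content": "Hello"}],
    optional_params={"max_tokens": 64},
    litellm_params={},
    headers={},
)
# Expected shape, per transform_request above:
# {
#     "text_input": "<prompt built by prompt_factory>",
#     "parameters": {"max_tokens": 64, "bad_words": [""], "stop_words": [""]},
#     "stream": False,
# }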
@@ -126,7 +126,7 @@ from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
-from .llms.triton import TritonChatCompletion
+from .llms.triton.completion.handler import TritonChatCompletion
 from .llms.vertex_ai import vertex_ai_non_gemini
 from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
 from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
@@ -559,7 +559,9 @@ def mock_completion(
         raise litellm.MockException(
             status_code=getattr(mock_response, "status_code", 500),  # type: ignore
             message=getattr(mock_response, "text", str(mock_response)),
-            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
             model=model,  # type: ignore
             request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
         )

@@ -568,7 +570,9 @@ def mock_completion(
     ):
         raise litellm.RateLimitError(
             message="this is a mock rate limit error",
-            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
             model=model,
         )
     elif (

@@ -577,7 +581,9 @@ def mock_completion(
     ):
         raise litellm.InternalServerError(
             message="this is a mock internal server error",
-            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
             model=model,
         )
     elif isinstance(mock_response, str) and mock_response.startswith(
@@ -2374,7 +2380,6 @@ def completion(  # type: ignore # noqa: PLR0915
                 return _model_response
             response = _model_response
         elif custom_llm_provider == "text-completion-codestral":
-
             api_base = (
                 api_base
                 or optional_params.pop("api_base", None)
@@ -2705,6 +2710,8 @@ def completion(  # type: ignore # noqa: PLR0915
                 logging_obj=logging,
                 stream=stream,
                 acompletion=acompletion,
+                client=client,
+                litellm_params=litellm_params,
             )
 
             ## RESPONSE OBJECT
@@ -2944,7 +2951,9 @@ def completion_with_retries(*args, **kwargs):
         )
 
     num_retries = kwargs.pop("num_retries", 3)
-    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry")  # type: ignore
+    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop(
+        "retry_strategy", "constant_retry"
+    )  # type: ignore
     original_function = kwargs.pop("original_function", completion)
     if retry_strategy == "exponential_backoff_retry":
         retryer = tenacity.Retrying(
@@ -3331,9 +3340,7 @@ def embedding(  # noqa: PLR0915
             max_retries=max_retries,
         )
     elif custom_llm_provider == "databricks":
-        api_base = (
-            api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")
-        )  # type: ignore
+        api_base = api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")  # type: ignore
 
         # set API KEY
         api_key = (
@@ -3465,7 +3472,6 @@ def embedding(  # noqa: PLR0915
             aembedding=aembedding,
         )
     elif custom_llm_provider == "gemini":
-
         gemini_api_key = (
             api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
         )
@@ -3960,7 +3966,11 @@ def text_completion(  # noqa: PLR0915
     optional_params["custom_llm_provider"] = custom_llm_provider
 
     # get custom_llm_provider
-    _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model,  # type: ignore
+        custom_llm_provider=custom_llm_provider,
+        api_base=api_base,
+    )
 
     if custom_llm_provider == "huggingface":
         # if echo == True, for TGI llms we need to set top_n_tokens to 3
@@ -4212,7 +4222,6 @@ async def amoderation(
         )
     openai_client = kwargs.get("client", None)
     if openai_client is None or not isinstance(openai_client, AsyncOpenAI):
-
         # call helper to get OpenAI client
         # _get_openai_client maintains in-memory caching logic for OpenAI clients
         _openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client(  # type: ignore
@@ -4322,7 +4331,11 @@ def image_generation(  # noqa: PLR0915
         headers.update(extra_headers)
     model_response: ImageResponse = litellm.utils.ImageResponse()
     if model is not None or custom_llm_provider is not None:
-        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+            model=model,  # type: ignore
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+        )
     else:
         model = "dall-e-2"
         custom_llm_provider = "openai"  # default to dall-e-2 on openai
@@ -4644,7 +4657,9 @@ def transcription(
 
     model_response = litellm.utils.TranscriptionResponse()
 
-    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
+    )  # type: ignore
 
     if dynamic_api_key is not None:
         api_key = dynamic_api_key
@@ -4710,12 +4725,7 @@ def transcription(
             or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
         )
         # set API KEY
-        api_key = (
-            api_key
-            or litellm.api_key
-            or litellm.openai_key
-            or get_secret("OPENAI_API_KEY")
-        )  # type: ignore
+        api_key = api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")  # type: ignore
         response = openai_audio_transcriptions.audio_transcriptions(
             model=model,
             audio_file=file,
@@ -4802,7 +4812,9 @@ def speech(
     proxy_server_request = kwargs.get("proxy_server_request", None)
     extra_headers = kwargs.get("extra_headers", None)
     model_info = kwargs.get("model_info", None)
-    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
+    )  # type: ignore
     kwargs.pop("tags", [])
 
     optional_params = {}
@@ -4895,9 +4907,7 @@ def speech(
         )
         api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")  # type: ignore
 
-        api_version = (
-            api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
-        )  # type: ignore
+        api_version = api_version or litellm.api_version or get_secret("AZURE_API_VERSION")  # type: ignore
 
         api_key = (
             api_key
@@ -5004,7 +5014,6 @@ async def ahealth_check(  # noqa: PLR0915
     """
     passed_in_mode: Optional[str] = None
     try:
-
         model: Optional[str] = model_params.get("model", None)
 
         if model is None:
@@ -6373,6 +6373,8 @@ class ProviderConfigManager:
             return litellm.OllamaConfig()
         elif litellm.LlmProviders.PREDIBASE == provider:
             return litellm.PredibaseConfig()
+        elif litellm.LlmProviders.TRITON == provider:
+            return litellm.TritonConfig()
         return litellm.OpenAIGPTConfig()
 
 
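A short sketch of how this new branch might be exercised, assuming ProviderConfigManager is importable from litellm.utils and that get_provider_chat_config takes model and provider arguments, as the commented-out test code in the next hunk suggests.

import litellm
from litellm.utils import ProviderConfigManager

# Assumed signature: get_provider_chat_config(model, provider)
config = ProviderConfigManager.get_provider_chat_config(
    model="llama-3-8b-instruct",
    provider=litellm.LlmProviders.TRITON,
)
assert isinstance(config, litellm.TritonConfig)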
@@ -322,11 +322,7 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders):
    #     or provider == LlmProviders.VERTEX_AI_BETA
    #     or provider == LlmProviders.BEDROCK
    #     or provider == LlmProviders.BASETEN
    #     or provider == LlmProviders.SAGEMAKER
    #     or provider == LlmProviders.SAGEMAKER_CHAT
    #     or provider == LlmProviders.VLLM
    #     or provider == LlmProviders.PETALS
    #     or provider == LlmProviders.OLLAMA
    # ):
    #     continue
    # config = ProviderConfigManager.get_provider_chat_config(
@@ -1,5 +1,5 @@
 import pytest
-from litellm.llms.triton import TritonChatCompletion
+from litellm.llms.triton.completion.handler import TritonChatCompletion
 
 
 def test_split_embedding_by_shape_passes():
@@ -29,3 +29,28 @@ def test_split_embedding_by_shape_fails_with_shape_value_error():
     ]
     with pytest.raises(ValueError):
         triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
+
+
+def test_completion_triton():
+    from litellm import completion
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    from unittest.mock import patch, MagicMock, AsyncMock
+
+    client = HTTPHandler()
+    with patch.object(client, "post") as mock_post:
+        try:
+            response = completion(
+                model="triton/llama-3-8b-instruct",
+                messages=[{"role": "user", "content": "who are u?"}],
+                max_tokens=10,
+                timeout=5,
+                client=client,
+                api_base="http://localhost:8000/generate",
+            )
+            print(response)
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
+
+        print(mock_post.call_args.kwargs)
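A possible follow-up assertion for the test above (not part of this commit), assuming the handler passes the json.dumps-serialized Triton payload to post() via the data keyword; assert_triton_payload is a hypothetical helper.

import json


def assert_triton_payload(mock_post) -> None:
    # Hypothetical check: decode the body the mocked post() received and
    # verify it matches the transform_request payload shape.
    posted = json.loads(mock_post.call_args.kwargs["data"])  # assumes data= kwarg
    assert "text_input" in posted
    assert posted["parameters"]["max_tokens"] == 10

# usage inside the with-block above: assert_triton_payload(mock_post)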