Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

build: Squashed commit of https://github.com/BerriAI/litellm/pull/7170

Closes https://github.com/BerriAI/litellm/pull/7170

This commit is contained in:
parent 5d1274cb6e
commit 06074bb13b

8 changed files with 197 additions and 62 deletions
@@ -1070,6 +1070,7 @@ from .llms.anthropic.experimental_pass_through.transformation import (
 )
 from .llms.groq.stt.transformation import GroqSTTConfig
 from .llms.anthropic.completion.transformation import AnthropicTextConfig
+from .llms.triton.completion.transformation import TritonConfig
 from .llms.databricks.chat.transformation import DatabricksConfig
 from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
 from .llms.predibase.chat.transformation import PredibaseConfig
litellm/llms/triton/common_utils.py (new file, 15 lines)

@@ -0,0 +1,15 @@
+from typing import Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class TritonError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[Union[dict, httpx.Headers]] = None,
+    ) -> None:
+        super().__init__(status_code=status_code, message=message, headers=headers)
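
The shared TritonError above subclasses BaseLLMException, so a failed Triton call carries the status code, message, and response headers that litellm's common error handling expects. A minimal sketch of raising it from a failed /generate response (assuming litellm at this commit is importable; the response values are illustrative):

import httpx

from litellm.llms.triton.common_utils import TritonError

# illustrative failed response from a Triton /generate call
resp = httpx.Response(status_code=422, text="malformed text_input")

if resp.status_code >= 400:
    # status_code, message, and headers are forwarded to BaseLLMException
    raise TritonError(
        status_code=resp.status_code,
        message=resp.text,
        headers=resp.headers,
    )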
@@ -5,12 +5,12 @@ from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
 import httpx  # type: ignore
-import requests  # type: ignore
 
 import litellm
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
+    _get_httpx_client,
     get_async_httpx_client,
 )
 from litellm.utils import (
@@ -24,22 +24,9 @@ from litellm.utils import (
     map_finish_reason,
 )
 
-from .base import BaseLLM
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
-
-
-class TritonError(Exception):
-    def __init__(self, status_code: int, message: str) -> None:
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(
-            method="POST",
-            url="https://api.anthropic.com/v1/messages",  # using anthropic api base since httpx requires a url
-        )
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
+from ...base import BaseLLM
+from ...prompt_templates.factory import custom_prompt, prompt_factory
+from ..common_utils import TritonError
 
 
 class TritonChatCompletion(BaseLLM):
@@ -142,31 +129,29 @@ class TritonChatCompletion(BaseLLM):
     def completion(
         self,
         model: str,
-        messages: List[dict],
+        messages: List,
         timeout: float,
         api_base: str,
         logging_obj: Any,
         optional_params: dict,
+        litellm_params: dict,
         model_response: ModelResponse,
         api_key: Optional[str] = None,
-        client=None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
         stream: Optional[bool] = False,
         acompletion: bool = False,
+        headers: Optional[dict] = None,
     ) -> ModelResponse:
         type_of_model = ""
         optional_params.pop("stream", False)
         if api_base.endswith("generate"):  ### This is a trtllm model
-            text_input = messages[0]["content"]
-            data_for_triton: Dict[str, Any] = {
-                "text_input": prompt_factory(model=model, messages=messages),
-                "parameters": {
-                    "max_tokens": int(optional_params.get("max_tokens", 2000)),
-                    "bad_words": [""],
-                    "stop_words": [""],
-                },
-                "stream": bool(stream),
-            }
-            data_for_triton["parameters"].update(optional_params)
+            data_for_triton = litellm.TritonConfig().transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                headers=headers or {},
+            )
             type_of_model = "trtllm"
 
         elif api_base.endswith(
@@ -226,7 +211,13 @@ class TritonChatCompletion(BaseLLM):
             },
         )
 
-        headers = {"Content-Type": "application/json"}
+        headers = litellm.TritonConfig().validate_environment(
+            headers=headers or {},
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
         json_data_for_triton: str = json.dumps(data_for_triton)
 
         if acompletion:
@@ -240,8 +231,12 @@
                 model_response=model_response,
                 type_of_model=type_of_model,
             )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            handler = _get_httpx_client()
         else:
-            handler = HTTPHandler()
+            handler = client
+
         if stream:
             return self._handle_stream(  # type: ignore
                 handler, api_base, json_data_for_triton, model, logging_obj
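
With the refactor above, completion() accepts an injected HTTPHandler and only falls back to _get_httpx_client() when none is given, so one client (and its connection pool) can be reused across Triton calls. A minimal usage sketch, assuming litellm at this commit and a Triton server reachable at the example /generate endpoint:

from litellm import completion
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()  # reused for every request below

for question in ["who are u?", "what can you do?"]:
    response = completion(
        model="triton/llama-3-8b-instruct",  # example model name
        messages=[{"role": "user", "content": question}],
        api_base="http://localhost:8000/generate",  # example endpoint
        max_tokens=10,
        client=client,
    )
    print(response.choices[0].message.content)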
litellm/llms/triton/completion/transformation.py (new file, 92 lines)

@@ -0,0 +1,92 @@
+"""
+Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint.
+"""
+
+from typing import Any, Dict, List, Optional, Union
+
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.transformation import (
+    BaseConfig,
+    BaseLLMException,
+    LiteLLMLoggingObj,
+)
+from litellm.llms.prompt_templates.factory import prompt_factory
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ..common_utils import TritonError
+
+
+class TritonConfig(BaseConfig):
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        inference_params = optional_params.copy()
+        stream = inference_params.pop("stream", False)
+        data_for_triton: Dict[str, Any] = {
+            "text_input": prompt_factory(model=model, messages=messages),
+            "parameters": {
+                "max_tokens": int(optional_params.get("max_tokens", 2000)),
+                "bad_words": [""],
+                "stop_words": [""],
+            },
+            "stream": bool(stream),
+        }
+        data_for_triton["parameters"].update(inference_params)
+        return data_for_triton
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
+    ) -> BaseLLMException:
+        return TritonError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return ["max_tokens", "max_completion_tokens"]
+
+    def map_openai_params(
+        self,
+        non_default_params: Dict,
+        optional_params: Dict,
+        model: str,
+        drop_params: bool,
+    ) -> Dict:
+        for param, value in non_default_params.items():
+            if param == "max_tokens" or param == "max_completion_tokens":
+                optional_params[param] = value
+        return optional_params
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: Dict,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "response transformation done in handler.py. [TODO] Migrate here."
+        )
+
+    def validate_environment(
+        self,
+        headers: Dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        api_key: Optional[str] = None,
+    ) -> Dict:
+        return {"Content-Type": "application/json"}
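
A quick usage sketch for the new config (assuming litellm at this commit is installed; the model name and parameters are illustrative). It shows the request body transform_request builds for Triton's /generate endpoint, which handler.py then serialises with json.dumps and posts:

import json

import litellm

body = litellm.TritonConfig().transform_request(
    model="triton/llama-3-8b-instruct",
    messages=[{"role": "user", "content": "who are u?"}],
    optional_params={"max_tokens": 10},
    litellm_params={},
    headers={},
)
print(json.dumps(body, indent=2))
# expected shape: {"text_input": "...", "parameters": {"max_tokens": 10,
#                  "bad_words": [""], "stop_words": [""]}, "stream": false}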
@@ -126,7 +126,7 @@ from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
-from .llms.triton import TritonChatCompletion
+from .llms.triton.completion.handler import TritonChatCompletion
 from .llms.vertex_ai import vertex_ai_non_gemini
 from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
 from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
@@ -559,7 +559,9 @@ def mock_completion(
         raise litellm.MockException(
             status_code=getattr(mock_response, "status_code", 500),  # type: ignore
             message=getattr(mock_response, "text", str(mock_response)),
-            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
             model=model,  # type: ignore
             request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
         )
@@ -568,7 +570,9 @@ def mock_completion(
     ):
         raise litellm.RateLimitError(
             message="this is a mock rate limit error",
-            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
             model=model,
         )
     elif (
@@ -577,7 +581,9 @@ def mock_completion(
     ):
         raise litellm.InternalServerError(
             message="this is a mock internal server error",
-            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
            model=model,
         )
     elif isinstance(mock_response, str) and mock_response.startswith(
@@ -2374,7 +2380,6 @@ def completion(  # type: ignore # noqa: PLR0915
             return _model_response
         response = _model_response
     elif custom_llm_provider == "text-completion-codestral":
-
         api_base = (
             api_base
             or optional_params.pop("api_base", None)
@@ -2705,6 +2710,8 @@ def completion(  # type: ignore # noqa: PLR0915
             logging_obj=logging,
             stream=stream,
             acompletion=acompletion,
+            client=client,
+            litellm_params=litellm_params,
         )
 
         ## RESPONSE OBJECT
@@ -2944,7 +2951,9 @@ def completion_with_retries(*args, **kwargs):
         )
 
     num_retries = kwargs.pop("num_retries", 3)
-    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry")  # type: ignore
+    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop(
+        "retry_strategy", "constant_retry"
+    )  # type: ignore
     original_function = kwargs.pop("original_function", completion)
     if retry_strategy == "exponential_backoff_retry":
         retryer = tenacity.Retrying(
@@ -3331,9 +3340,7 @@ def embedding(  # noqa: PLR0915
             max_retries=max_retries,
         )
     elif custom_llm_provider == "databricks":
-        api_base = (
-            api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")
-        )  # type: ignore
+        api_base = api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")  # type: ignore
 
         # set API KEY
         api_key = (
@@ -3465,7 +3472,6 @@ def embedding(  # noqa: PLR0915
             aembedding=aembedding,
         )
     elif custom_llm_provider == "gemini":
-
         gemini_api_key = (
             api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
         )
@@ -3960,7 +3966,11 @@ def text_completion(  # noqa: PLR0915
     optional_params["custom_llm_provider"] = custom_llm_provider

     # get custom_llm_provider
-    _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model,  # type: ignore
+        custom_llm_provider=custom_llm_provider,
+        api_base=api_base,
+    )

     if custom_llm_provider == "huggingface":
         # if echo == True, for TGI llms we need to set top_n_tokens to 3
@@ -4212,7 +4222,6 @@ async def amoderation(
     )
     openai_client = kwargs.get("client", None)
     if openai_client is None or not isinstance(openai_client, AsyncOpenAI):
-
         # call helper to get OpenAI client
         # _get_openai_client maintains in-memory caching logic for OpenAI clients
         _openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client(  # type: ignore
@@ -4322,7 +4331,11 @@ def image_generation(  # noqa: PLR0915
         headers.update(extra_headers)
     model_response: ImageResponse = litellm.utils.ImageResponse()
     if model is not None or custom_llm_provider is not None:
-        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+            model=model,  # type: ignore
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+        )
     else:
         model = "dall-e-2"
         custom_llm_provider = "openai"  # default to dall-e-2 on openai
@@ -4644,7 +4657,9 @@ def transcription(
 
     model_response = litellm.utils.TranscriptionResponse()
 
-    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
+    )  # type: ignore
 
     if dynamic_api_key is not None:
         api_key = dynamic_api_key
@@ -4710,12 +4725,7 @@ def transcription(
         or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
     )
     # set API KEY
-    api_key = (
-        api_key
-        or litellm.api_key
-        or litellm.openai_key
-        or get_secret("OPENAI_API_KEY")
-    )  # type: ignore
+    api_key = api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")  # type: ignore
     response = openai_audio_transcriptions.audio_transcriptions(
         model=model,
         audio_file=file,
@@ -4802,7 +4812,9 @@ def speech(
     proxy_server_request = kwargs.get("proxy_server_request", None)
     extra_headers = kwargs.get("extra_headers", None)
     model_info = kwargs.get("model_info", None)
-    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
+    )  # type: ignore
     kwargs.pop("tags", [])
 
     optional_params = {}
@@ -4895,9 +4907,7 @@ def speech(
         )
         api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")  # type: ignore
 
-        api_version = (
-            api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
-        )  # type: ignore
+        api_version = api_version or litellm.api_version or get_secret("AZURE_API_VERSION")  # type: ignore
 
         api_key = (
             api_key
@@ -5004,7 +5014,6 @@ async def ahealth_check(  # noqa: PLR0915
     """
     passed_in_mode: Optional[str] = None
     try:
-
         model: Optional[str] = model_params.get("model", None)
 
         if model is None:
@@ -6373,6 +6373,8 @@ class ProviderConfigManager:
             return litellm.OllamaConfig()
         elif litellm.LlmProviders.PREDIBASE == provider:
             return litellm.PredibaseConfig()
+        elif litellm.LlmProviders.TRITON == provider:
+            return litellm.TritonConfig()
         return litellm.OpenAIGPTConfig()
 
 
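
A hedged sketch of what the new branch enables, assuming ProviderConfigManager.get_provider_chat_config keeps its (model, provider) signature: looking up the Triton provider now returns the dedicated TritonConfig rather than the OpenAI default.

import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    model="llama-3-8b-instruct",  # illustrative model name
    provider=litellm.LlmProviders.TRITON,
)
assert isinstance(config, litellm.TritonConfig)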
@@ -322,11 +322,7 @@ def _check_provider_config(config: BaseConfig, provider: LlmProviders):
         # or provider == LlmProviders.VERTEX_AI_BETA
         # or provider == LlmProviders.BEDROCK
         # or provider == LlmProviders.BASETEN
-        # or provider == LlmProviders.SAGEMAKER
-        # or provider == LlmProviders.SAGEMAKER_CHAT
-        # or provider == LlmProviders.VLLM
         # or provider == LlmProviders.PETALS
-        # or provider == LlmProviders.OLLAMA
         # ):
         #     continue
         #     config = ProviderConfigManager.get_provider_chat_config(
@@ -1,5 +1,5 @@
 import pytest
-from litellm.llms.triton import TritonChatCompletion
+from litellm.llms.triton.completion.handler import TritonChatCompletion
 
 
 def test_split_embedding_by_shape_passes():
@@ -29,3 +29,28 @@ def test_split_embedding_by_shape_fails_with_shape_value_error():
     ]
     with pytest.raises(ValueError):
         triton.split_embedding_by_shape(data[0]["data"], data[0]["shape"])
+
+
+def test_completion_triton():
+    from litellm import completion
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    from unittest.mock import patch, MagicMock, AsyncMock
+
+    client = HTTPHandler()
+    with patch.object(client, "post") as mock_post:
+        try:
+            response = completion(
+                model="triton/llama-3-8b-instruct",
+                messages=[{"role": "user", "content": "who are u?"}],
+                max_tokens=10,
+                timeout=5,
+                client=client,
+                api_base="http://localhost:8000/generate",
+            )
+            print(response)
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
+
+        print(mock_post.call_args.kwargs)