mirror of https://github.com/BerriAI/litellm.git
import logging
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import httpx

from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any

from litellm.llms.base_llm.chat.transformation import BaseLLMException

from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping

logger = logging.getLogger(__name__)

BASE_URL = "https://router.huggingface.co"

class HuggingFaceChatConfig(OpenAIGPTConfig):
    """
    Reference: https://huggingface.co/docs/huggingface_hub/guides/inference
    """

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        default_headers = {
            "content-type": "application/json",
        }
        if api_key is not None:
            default_headers["Authorization"] = f"Bearer {api_key}"

        headers = {**headers, **default_headers}

        return headers

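    # Illustrative note (not part of the original module): the default headers
    # take precedence over caller-supplied ones, so with a hypothetical key
    # "hf_xxx" the call
    #   validate_environment(headers={}, model="m", messages=[],
    #                        optional_params={}, litellm_params={}, api_key="hf_xxx")
    # returns
    #   {"content-type": "application/json", "Authorization": "Bearer hf_xxx"}
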
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return HuggingFaceError(
            status_code=status_code, message=error_message, headers=headers
        )

    def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
        """
        Get the API base for the Huggingface API.

        Do not add the chat/embedding/rerank extension here. Let the handler do this.
        """
        if model.startswith(("http://", "https://")):
            base_url = model
        elif base_url is None:
            base_url = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
        return base_url

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the API call.

        Handles provider-specific routing through the Hugging Face router.
        """
        # 1. Check if api_base is provided
        if api_base is not None:
            complete_url = api_base
        # 2. Check the HF_API_BASE / HUGGINGFACE_API_BASE environment variables
        elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
            complete_url = os.getenv("HF_API_BASE") or os.getenv(
                "HUGGINGFACE_API_BASE", ""
            )
        # 3. Check if the model itself is a full URL
        elif model.startswith(("http://", "https://")):
            complete_url = model
        # 4. Default construction with provider
        else:
            # Parse provider and model ("provider/org/model" vs. plain "org/model")
            first_part, remaining = model.split("/", 1)
            if "/" in remaining:
                provider = first_part
            else:
                provider = "hf-inference"

            if provider == "hf-inference":
                route = f"{provider}/models/{model}/v1/chat/completions"
            elif provider == "novita":
                route = f"{provider}/chat/completions"
            else:
                route = f"{provider}/v1/chat/completions"
            complete_url = f"{BASE_URL}/{route}"

        # Ensure URL doesn't end with a slash
        complete_url = complete_url.rstrip("/")
        return complete_url

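    # Illustrative routing examples (not part of the original module). The model
    # strings are hypothetical and only show how get_complete_url composes the
    # router URL when no api_base or environment override is set:
    #   "meta-llama/Llama-3.1-8B-Instruct" (no provider prefix) ->
    #       https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions
    #   "together/meta-llama/Llama-3.1-8B-Instruct" ->
    #       https://router.huggingface.co/together/v1/chat/completions
    #   "novita/meta-llama/Llama-3.1-8B-Instruct" ->
    #       https://router.huggingface.co/novita/chat/completions
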
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        if "max_retries" in optional_params:
            logger.warning("`max_retries` is not supported. It will be ignored.")
            optional_params.pop("max_retries", None)

        # Split "provider/org/model" into provider and model id; default to
        # hf-inference when the model string has no provider prefix.
        first_part, remaining = model.split("/", 1)
        if "/" in remaining:
            provider = first_part
            model_id = remaining
        else:
            provider = "hf-inference"
            model_id = model

        # The hub mapping maps each provider name to its entry, including the
        # deployment "status" and the provider-specific "providerId".
        provider_mapping = _fetch_inference_provider_mapping(model_id)
        if provider not in provider_mapping:
            raise HuggingFaceError(
                message=f"Model {model_id} is not supported for provider {provider}",
                status_code=404,
                headers={},
            )
        provider_mapping = provider_mapping[provider]
        if provider_mapping["status"] == "staging":
            logger.warning(
                f"Model {model_id} is in staging mode for provider {provider}. Meant for test purposes only."
            )
        mapped_model = provider_mapping["providerId"]
        messages = self._transform_messages(messages=messages, model=mapped_model)
        return dict(
            ChatCompletionRequest(
                model=mapped_model, messages=messages, **optional_params
            )
        )
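

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The model string
# and message below are assumptions; transform_request resolves the provider's
# own model id via the Hugging Face hub mapping (a network call) and raises
# HuggingFaceError if the hub does not list that provider for the model.
if __name__ == "__main__":
    config = HuggingFaceChatConfig()

    # Compose the router URL for a hypothetical provider-prefixed model string.
    url = config.get_complete_url(
        api_base=None,
        api_key=os.getenv("HF_TOKEN"),
        model="together/meta-llama/Llama-3.1-8B-Instruct",
        optional_params={},
        litellm_params={},
    )

    # Build the outgoing request body for the same model; body["model"] is the
    # provider's own model id taken from the hub mapping.
    body = config.transform_request(
        model="together/meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "Hello"}],
        optional_params={"max_tokens": 64},
        litellm_params={},
        headers={},
    )
    print(url, body["model"])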