forked from phoenix/litellm-mirror
feat(proxy_cli.py): add new 'log_config' cli param (#6352)
* feat(proxy_cli.py): add new 'log_config' cli param Allows passing logging.conf to uvicorn on startup * docs(cli.md): add logging conf to uvicorn cli docs * fix(get_llm_provider_logic.py): fix default api base for litellm_proxy Fixes https://github.com/BerriAI/litellm/issues/6332 * feat(openai_like/embedding): Add support for jina ai embeddings Closes https://github.com/BerriAI/litellm/issues/6337 * docs(deploy.md): update entrypoint.sh filepath post-refactor Fixes outdated docs * feat(prometheus.py): emit time_to_first_token metric on prometheus Closes https://github.com/BerriAI/litellm/issues/6334 * fix(prometheus.py): only emit time to first token metric if stream is True enables more accurate ttft usage * test: handle vertex api instability * fix(get_llm_provider_logic.py): fix import * fix(openai.py): fix deepinfra default api base * fix(anthropic/transformation.py): remove anthropic beta header (#6361)
This commit is contained in:
parent
7338b24a74
commit
2b9db05e08
23 changed files with 839 additions and 263 deletions
|
@ -176,3 +176,11 @@ Cli arguments, --host, --port, --num_workers
|
|||
```
|
||||
|
||||
|
||||
## --log_config
|
||||
- **Default:** `None`
|
||||
- **Type:** `str`
|
||||
- Specify a log configuration file for uvicorn.
|
||||
- **Usage:**
|
||||
```shell
|
||||
litellm --log_config path/to/log_config.conf
|
||||
```
|
||||
|
|
|
@ -125,7 +125,7 @@ WORKDIR /app
|
|||
COPY config.yaml .
|
||||
|
||||
# Make sure your docker/entrypoint.sh is executable
|
||||
RUN chmod +x entrypoint.sh
|
||||
RUN chmod +x ./docker/entrypoint.sh
|
||||
|
||||
# Expose the necessary port
|
||||
EXPOSE 4000/tcp
|
||||
|
@ -632,7 +632,7 @@ RUN rm -rf /app/litellm/proxy/_experimental/out/* && \
|
|||
WORKDIR /app
|
||||
|
||||
# Make sure your entrypoint.sh is executable
|
||||
RUN chmod +x entrypoint.sh
|
||||
RUN chmod +x ./docker/entrypoint.sh
|
||||
|
||||
# Expose the necessary port
|
||||
EXPOSE 4000/tcp
|
||||
|
|
|
@ -134,8 +134,9 @@ Use this for LLM API Error monitoring and tracking remaining rate limits and tok
|
|||
|
||||
| Metric Name | Description |
|
||||
|----------------------|--------------------------------------|
|
||||
| `litellm_request_total_latency_metric` | Total latency (seconds) for a request to LiteLLM Proxy Server - tracked for labels `litellm_call_id`, `model`, `user_api_key`, `user_api_key_alias`, `user_api_team`, `user_api_team_alias` |
|
||||
| `litellm_llm_api_latency_metric` | Latency (seconds) for just the LLM API call - tracked for labels `litellm_call_id`, `model`, `user_api_key`, `user_api_key_alias`, `user_api_team`, `user_api_team_alias` |
|
||||
| `litellm_request_total_latency_metric` | Total latency (seconds) for a request to LiteLLM Proxy Server - tracked for labels `model`, `hashed_api_key`, `api_key_alias`, `team`, `team_alias` |
|
||||
| `litellm_llm_api_latency_metric` | Latency (seconds) for just the LLM API call - tracked for labels `model`, `hashed_api_key`, `api_key_alias`, `team`, `team_alias` |
|
||||
| `litellm_llm_api_time_to_first_token_metric` | Time to first token for LLM API call - tracked for labels `model`, `hashed_api_key`, `api_key_alias`, `team`, `team_alias` [Note: only emitted for streaming requests] |
|
||||
|
||||
## Virtual Key - Budget, Rate Limit Metrics
|
||||
|
||||
|
|
|
@ -98,6 +98,7 @@ api_key: Optional[str] = None
|
|||
openai_key: Optional[str] = None
|
||||
groq_key: Optional[str] = None
|
||||
databricks_key: Optional[str] = None
|
||||
openai_like_key: Optional[str] = None
|
||||
azure_key: Optional[str] = None
|
||||
anthropic_key: Optional[str] = None
|
||||
replicate_key: Optional[str] = None
|
||||
|
@ -710,6 +711,8 @@ model_list = (
|
|||
|
||||
class LlmProviders(str, Enum):
|
||||
OPENAI = "openai"
|
||||
OPENAI_LIKE = "openai_like" # embedding only
|
||||
JINA_AI = "jina_ai"
|
||||
CUSTOM_OPENAI = "custom_openai"
|
||||
TEXT_COMPLETION_OPENAI = "text-completion-openai"
|
||||
COHERE = "cohere"
|
||||
|
@ -1013,6 +1016,7 @@ from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfi
|
|||
from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
|
||||
FireworksAIEmbeddingConfig,
|
||||
)
|
||||
from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
|
||||
from .llms.volcengine import VolcEngineConfig
|
||||
from .llms.text_completion_codestral import MistralTextCompletionConfig
|
||||
from .llms.AzureOpenAI.azure import (
|
||||
|
@ -1022,6 +1026,7 @@ from .llms.AzureOpenAI.azure import (
|
|||
|
||||
from .llms.AzureOpenAI.chat.gpt_transformation import AzureOpenAIConfig
|
||||
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
|
||||
from .llms.perplexity.chat.transformation import PerplexityChatConfig
|
||||
from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config
|
||||
from .llms.watsonx import IBMWatsonXAIConfig
|
||||
from .main import * # type: ignore
|
||||
|
|
|
@ -97,6 +97,19 @@ class PrometheusLogger(CustomLogger):
|
|||
buckets=LATENCY_BUCKETS,
|
||||
)
|
||||
|
||||
self.litellm_llm_api_time_to_first_token_metric = Histogram(
|
||||
"litellm_llm_api_time_to_first_token_metric",
|
||||
"Time to first token for a models LLM API call",
|
||||
labelnames=[
|
||||
"model",
|
||||
"hashed_api_key",
|
||||
"api_key_alias",
|
||||
"team",
|
||||
"team_alias",
|
||||
],
|
||||
buckets=LATENCY_BUCKETS,
|
||||
)
|
||||
|
||||
# Counter for spend
|
||||
self.litellm_spend_metric = Counter(
|
||||
"litellm_spend_metric",
|
||||
|
@ -335,14 +348,17 @@ class PrometheusLogger(CustomLogger):
|
|||
)
|
||||
|
||||
# unpack kwargs
|
||||
standard_logging_payload: StandardLoggingPayload = kwargs.get(
|
||||
"standard_logging_object", {}
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
if standard_logging_payload is None:
|
||||
raise ValueError("standard_logging_object is required")
|
||||
model = kwargs.get("model", "")
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
_metadata = litellm_params.get("metadata", {})
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
model_parameters: dict = standard_logging_payload["model_parameters"]
|
||||
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
|
||||
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
|
||||
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
|
||||
|
@ -468,6 +484,28 @@ class PrometheusLogger(CustomLogger):
|
|||
total_time_seconds = total_time.total_seconds()
|
||||
api_call_start_time = kwargs.get("api_call_start_time", None)
|
||||
|
||||
completion_start_time = kwargs.get("completion_start_time", None)
|
||||
|
||||
if (
|
||||
completion_start_time is not None
|
||||
and isinstance(completion_start_time, datetime)
|
||||
and model_parameters.get("stream")
|
||||
is True # only emit for streaming requests
|
||||
):
|
||||
time_to_first_token_seconds = (
|
||||
completion_start_time - api_call_start_time
|
||||
).total_seconds()
|
||||
self.litellm_llm_api_time_to_first_token_metric.labels(
|
||||
model,
|
||||
user_api_key,
|
||||
user_api_key_alias,
|
||||
user_api_team,
|
||||
user_api_team_alias,
|
||||
).observe(time_to_first_token_seconds)
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
"Time to first token metric not emitted, stream option in model_parameters is not True"
|
||||
)
|
||||
if api_call_start_time is not None and isinstance(
|
||||
api_call_start_time, datetime
|
||||
):
|
||||
|
@ -512,6 +550,7 @@ class PrometheusLogger(CustomLogger):
|
|||
"standard_logging_object", {}
|
||||
)
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
|
||||
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
|
||||
|
|
|
@ -4,7 +4,7 @@ import httpx
|
|||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.secret_managers.main import get_secret, get_secret_str
|
||||
|
||||
from ..types.router import LiteLLM_Params
|
||||
|
||||
|
@ -22,21 +22,6 @@ def _is_non_openai_azure_model(model: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _is_azure_openai_model(model: str) -> bool:
|
||||
try:
|
||||
if "/" in model:
|
||||
model = model.split("/", 1)[1]
|
||||
if (
|
||||
model in litellm.open_ai_chat_completion_models
|
||||
or model in litellm.open_ai_text_completion_models
|
||||
or model in litellm.open_ai_embedding_models
|
||||
):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def handle_cohere_chat_model_custom_llm_provider(
|
||||
model: str, custom_llm_provider: Optional[str] = None
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
|
@ -116,7 +101,7 @@ def get_llm_provider( # noqa: PLR0915
|
|||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
|
||||
if api_key and api_key.startswith("os.environ/"):
|
||||
dynamic_api_key = get_secret(api_key)
|
||||
dynamic_api_key = get_secret_str(api_key)
|
||||
# check if llm provider part of model name
|
||||
if (
|
||||
model.split("/", 1)[0] in litellm.provider_list
|
||||
|
@ -124,204 +109,12 @@ def get_llm_provider( # noqa: PLR0915
|
|||
and len(model.split("/"))
|
||||
> 1 # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351
|
||||
):
|
||||
custom_llm_provider = model.split("/", 1)[0]
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
if custom_llm_provider == "perplexity":
|
||||
# perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai
|
||||
api_base = api_base or get_secret("PERPLEXITY_API_BASE") or "https://api.perplexity.ai" # type: ignore
|
||||
dynamic_api_key = (
|
||||
api_key
|
||||
or get_secret("PERPLEXITYAI_API_KEY")
|
||||
or get_secret("PERPLEXITY_API_KEY")
|
||||
return _get_openai_compatible_provider_info(
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
dynamic_api_key=dynamic_api_key,
|
||||
)
|
||||
elif custom_llm_provider == "anyscale":
|
||||
# anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
|
||||
api_base = api_base or get_secret("ANYSCALE_API_BASE") or "https://api.endpoints.anyscale.com/v1" # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("ANYSCALE_API_KEY")
|
||||
elif custom_llm_provider == "deepinfra":
|
||||
# deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
|
||||
api_base = api_base or get_secret("DEEPINFRA_API_BASE") or "https://api.deepinfra.com/v1/openai" # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("DEEPINFRA_API_KEY")
|
||||
elif custom_llm_provider == "empower":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("EMPOWER_API_BASE")
|
||||
or "https://app.empower.dev/api/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("EMPOWER_API_KEY")
|
||||
elif custom_llm_provider == "groq":
|
||||
# groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("GROQ_API_BASE")
|
||||
or "https://api.groq.com/openai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("GROQ_API_KEY")
|
||||
elif custom_llm_provider == "nvidia_nim":
|
||||
# nvidia_nim is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("NVIDIA_NIM_API_BASE")
|
||||
or "https://integrate.api.nvidia.com/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("NVIDIA_NIM_API_KEY")
|
||||
elif custom_llm_provider == "cerebras":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("CEREBRAS_API_BASE")
|
||||
or "https://api.cerebras.ai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("CEREBRAS_API_KEY")
|
||||
elif custom_llm_provider == "sambanova":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("SAMBANOVA_API_BASE")
|
||||
or "https://api.sambanova.ai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("SAMBANOVA_API_KEY")
|
||||
elif (custom_llm_provider == "ai21_chat") or (
|
||||
custom_llm_provider == "ai21" and model in litellm.ai21_chat_models
|
||||
):
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("AI21_API_BASE")
|
||||
or "https://api.ai21.com/studio/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("AI21_API_KEY")
|
||||
custom_llm_provider = "ai21_chat"
|
||||
elif custom_llm_provider == "volcengine":
|
||||
# volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("VOLCENGINE_API_BASE")
|
||||
or "https://ark.cn-beijing.volces.com/api/v3"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("VOLCENGINE_API_KEY")
|
||||
elif custom_llm_provider == "codestral":
|
||||
# codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("CODESTRAL_API_BASE")
|
||||
or "https://codestral.mistral.ai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("CODESTRAL_API_KEY")
|
||||
elif custom_llm_provider == "hosted_vllm":
|
||||
# vllm is openai compatible, we just need to set this to custom_openai
|
||||
api_base = api_base or get_secret(
|
||||
"HOSTED_VLLM_API_BASE"
|
||||
) # type: ignore
|
||||
dynamic_api_key = (
|
||||
api_key or get_secret("HOSTED_VLLM_API_KEY") or ""
|
||||
) # vllm does not require an api key
|
||||
elif custom_llm_provider == "deepseek":
|
||||
# deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("DEEPSEEK_API_BASE")
|
||||
or "https://api.deepseek.com/beta"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY")
|
||||
elif custom_llm_provider == "fireworks_ai":
|
||||
# fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
|
||||
if litellm.FireworksAIEmbeddingConfig().is_fireworks_embedding_model(
|
||||
model=model
|
||||
):
|
||||
# fireworks embeddings models do no require accounts/fireworks prefix https://docs.fireworks.ai/api-reference/creates-an-embedding-vector-representing-the-input-text
|
||||
pass
|
||||
elif not model.startswith("accounts/"):
|
||||
model = f"accounts/fireworks/models/{model}"
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("FIREWORKS_API_BASE")
|
||||
or "https://api.fireworks.ai/inference/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or (
|
||||
get_secret("FIREWORKS_API_KEY")
|
||||
or get_secret("FIREWORKS_AI_API_KEY")
|
||||
or get_secret("FIREWORKSAI_API_KEY")
|
||||
or get_secret("FIREWORKS_AI_TOKEN")
|
||||
)
|
||||
elif custom_llm_provider == "azure_ai":
|
||||
api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
|
||||
|
||||
if _is_azure_openai_model(model=model):
|
||||
verbose_logger.debug(
|
||||
"Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
|
||||
model
|
||||
)
|
||||
)
|
||||
custom_llm_provider = "azure"
|
||||
elif custom_llm_provider == "github":
|
||||
api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")
|
||||
elif custom_llm_provider == "litellm_proxy":
|
||||
api_base = api_base or get_secret("LITELLM_PROXY_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("LITELLM_PROXY_API_KEY")
|
||||
|
||||
elif custom_llm_provider == "mistral":
|
||||
# mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("MISTRAL_AZURE_API_BASE") # for Azure AI Mistral
|
||||
or "https://api.mistral.ai/v1"
|
||||
) # type: ignore
|
||||
|
||||
# if api_base does not end with /v1 we add it
|
||||
if api_base is not None and not api_base.endswith(
|
||||
"/v1"
|
||||
): # Mistral always needs a /v1 at the end
|
||||
api_base = api_base + "/v1"
|
||||
dynamic_api_key = (
|
||||
api_key
|
||||
or get_secret("MISTRAL_AZURE_API_KEY") # for Azure AI Mistral
|
||||
or get_secret("MISTRAL_API_KEY")
|
||||
)
|
||||
elif custom_llm_provider == "voyage":
|
||||
# voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("VOYAGE_API_BASE")
|
||||
or "https://api.voyageai.com/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("VOYAGE_API_KEY")
|
||||
elif custom_llm_provider == "together_ai":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("TOGETHER_AI_API_BASE")
|
||||
or "https://api.together.xyz/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or (
|
||||
get_secret("TOGETHER_API_KEY")
|
||||
or get_secret("TOGETHER_AI_API_KEY")
|
||||
or get_secret("TOGETHERAI_API_KEY")
|
||||
or get_secret("TOGETHER_AI_TOKEN")
|
||||
)
|
||||
elif custom_llm_provider == "friendliai":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("FRIENDLI_API_BASE")
|
||||
or "https://inference.friendli.ai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = (
|
||||
api_key
|
||||
or get_secret("FRIENDLIAI_API_KEY")
|
||||
or get_secret("FRIENDLI_TOKEN")
|
||||
)
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception(
|
||||
"api base needs to be a string. api_base={}".format(api_base)
|
||||
)
|
||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||
raise Exception(
|
||||
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||
dynamic_api_key
|
||||
)
|
||||
)
|
||||
if dynamic_api_key is None and api_key is not None:
|
||||
dynamic_api_key = api_key
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
elif model.split("/", 1)[0] in litellm.provider_list:
|
||||
custom_llm_provider = model.split("/", 1)[0]
|
||||
model = model.split("/", 1)[1]
|
||||
|
@ -342,46 +135,46 @@ def get_llm_provider( # noqa: PLR0915
|
|||
if endpoint in api_base:
|
||||
if endpoint == "api.perplexity.ai":
|
||||
custom_llm_provider = "perplexity"
|
||||
dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY")
|
||||
dynamic_api_key = get_secret_str("PERPLEXITYAI_API_KEY")
|
||||
elif endpoint == "api.endpoints.anyscale.com/v1":
|
||||
custom_llm_provider = "anyscale"
|
||||
dynamic_api_key = get_secret("ANYSCALE_API_KEY")
|
||||
dynamic_api_key = get_secret_str("ANYSCALE_API_KEY")
|
||||
elif endpoint == "api.deepinfra.com/v1/openai":
|
||||
custom_llm_provider = "deepinfra"
|
||||
dynamic_api_key = get_secret("DEEPINFRA_API_KEY")
|
||||
dynamic_api_key = get_secret_str("DEEPINFRA_API_KEY")
|
||||
elif endpoint == "api.mistral.ai/v1":
|
||||
custom_llm_provider = "mistral"
|
||||
dynamic_api_key = get_secret("MISTRAL_API_KEY")
|
||||
dynamic_api_key = get_secret_str("MISTRAL_API_KEY")
|
||||
elif endpoint == "api.groq.com/openai/v1":
|
||||
custom_llm_provider = "groq"
|
||||
dynamic_api_key = get_secret("GROQ_API_KEY")
|
||||
dynamic_api_key = get_secret_str("GROQ_API_KEY")
|
||||
elif endpoint == "https://integrate.api.nvidia.com/v1":
|
||||
custom_llm_provider = "nvidia_nim"
|
||||
dynamic_api_key = get_secret("NVIDIA_NIM_API_KEY")
|
||||
dynamic_api_key = get_secret_str("NVIDIA_NIM_API_KEY")
|
||||
elif endpoint == "https://api.cerebras.ai/v1":
|
||||
custom_llm_provider = "cerebras"
|
||||
dynamic_api_key = get_secret("CEREBRAS_API_KEY")
|
||||
dynamic_api_key = get_secret_str("CEREBRAS_API_KEY")
|
||||
elif endpoint == "https://api.sambanova.ai/v1":
|
||||
custom_llm_provider = "sambanova"
|
||||
dynamic_api_key = get_secret("SAMBANOVA_API_KEY")
|
||||
dynamic_api_key = get_secret_str("SAMBANOVA_API_KEY")
|
||||
elif endpoint == "https://api.ai21.com/studio/v1":
|
||||
custom_llm_provider = "ai21_chat"
|
||||
dynamic_api_key = get_secret("AI21_API_KEY")
|
||||
dynamic_api_key = get_secret_str("AI21_API_KEY")
|
||||
elif endpoint == "https://codestral.mistral.ai/v1":
|
||||
custom_llm_provider = "codestral"
|
||||
dynamic_api_key = get_secret("CODESTRAL_API_KEY")
|
||||
dynamic_api_key = get_secret_str("CODESTRAL_API_KEY")
|
||||
elif endpoint == "https://codestral.mistral.ai/v1":
|
||||
custom_llm_provider = "text-completion-codestral"
|
||||
dynamic_api_key = get_secret("CODESTRAL_API_KEY")
|
||||
dynamic_api_key = get_secret_str("CODESTRAL_API_KEY")
|
||||
elif endpoint == "app.empower.dev/api/v1":
|
||||
custom_llm_provider = "empower"
|
||||
dynamic_api_key = get_secret("EMPOWER_API_KEY")
|
||||
dynamic_api_key = get_secret_str("EMPOWER_API_KEY")
|
||||
elif endpoint == "api.deepseek.com/v1":
|
||||
custom_llm_provider = "deepseek"
|
||||
dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
|
||||
dynamic_api_key = get_secret_str("DEEPSEEK_API_KEY")
|
||||
elif endpoint == "inference.friendli.ai/v1":
|
||||
custom_llm_provider = "friendliai"
|
||||
dynamic_api_key = get_secret(
|
||||
dynamic_api_key = get_secret_str(
|
||||
"FRIENDLIAI_API_KEY"
|
||||
) or get_secret("FRIENDLI_TOKEN")
|
||||
|
||||
|
@ -485,7 +278,7 @@ def get_llm_provider( # noqa: PLR0915
|
|||
custom_llm_provider = "empower"
|
||||
elif model == "*":
|
||||
custom_llm_provider = "openai"
|
||||
if custom_llm_provider is None or custom_llm_provider == "":
|
||||
if not custom_llm_provider:
|
||||
if litellm.suppress_debug_info is False:
|
||||
print() # noqa
|
||||
print( # noqa
|
||||
|
@ -532,3 +325,192 @@ def get_llm_provider( # noqa: PLR0915
|
|||
),
|
||||
llm_provider="",
|
||||
)
|
||||
|
||||
|
||||
def _get_openai_compatible_provider_info( # noqa: PLR0915
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
dynamic_api_key: Optional[str],
|
||||
) -> Tuple[str, str, Optional[str], Optional[str]]:
|
||||
custom_llm_provider = model.split("/", 1)[0]
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
if custom_llm_provider == "perplexity":
|
||||
# perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai
|
||||
(
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
) = litellm.PerplexityChatConfig()._get_openai_compatible_provider_info(
|
||||
api_base, api_key
|
||||
)
|
||||
elif custom_llm_provider == "anyscale":
|
||||
# anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
|
||||
api_base = api_base or get_secret_str("ANYSCALE_API_BASE") or "https://api.endpoints.anyscale.com/v1" # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("ANYSCALE_API_KEY")
|
||||
elif custom_llm_provider == "deepinfra":
|
||||
(
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
) = litellm.DeepInfraConfig()._get_openai_compatible_provider_info(
|
||||
api_base, api_key
|
||||
)
|
||||
elif custom_llm_provider == "empower":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("EMPOWER_API_BASE")
|
||||
or "https://app.empower.dev/api/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("EMPOWER_API_KEY")
|
||||
elif custom_llm_provider == "groq":
|
||||
(
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
) = litellm.GroqChatConfig()._get_openai_compatible_provider_info(
|
||||
api_base, api_key
|
||||
)
|
||||
elif custom_llm_provider == "nvidia_nim":
|
||||
# nvidia_nim is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("NVIDIA_NIM_API_BASE")
|
||||
or "https://integrate.api.nvidia.com/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("NVIDIA_NIM_API_KEY")
|
||||
elif custom_llm_provider == "cerebras":
|
||||
api_base = (
|
||||
api_base or get_secret("CEREBRAS_API_BASE") or "https://api.cerebras.ai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("CEREBRAS_API_KEY")
|
||||
elif custom_llm_provider == "sambanova":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("SAMBANOVA_API_BASE")
|
||||
or "https://api.sambanova.ai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("SAMBANOVA_API_KEY")
|
||||
elif (custom_llm_provider == "ai21_chat") or (
|
||||
custom_llm_provider == "ai21" and model in litellm.ai21_chat_models
|
||||
):
|
||||
api_base = (
|
||||
api_base or get_secret("AI21_API_BASE") or "https://api.ai21.com/studio/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("AI21_API_KEY")
|
||||
custom_llm_provider = "ai21_chat"
|
||||
elif custom_llm_provider == "volcengine":
|
||||
# volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("VOLCENGINE_API_BASE")
|
||||
or "https://ark.cn-beijing.volces.com/api/v3"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("VOLCENGINE_API_KEY")
|
||||
elif custom_llm_provider == "codestral":
|
||||
# codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("CODESTRAL_API_BASE")
|
||||
or "https://codestral.mistral.ai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("CODESTRAL_API_KEY")
|
||||
elif custom_llm_provider == "hosted_vllm":
|
||||
# vllm is openai compatible, we just need to set this to custom_openai
|
||||
(
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
) = litellm.HostedVLLMChatConfig()._get_openai_compatible_provider_info(
|
||||
api_base, api_key
|
||||
)
|
||||
elif custom_llm_provider == "deepseek":
|
||||
# deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("DEEPSEEK_API_BASE")
|
||||
or "https://api.deepseek.com/beta"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("DEEPSEEK_API_KEY")
|
||||
elif custom_llm_provider == "fireworks_ai":
|
||||
# fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
|
||||
(
|
||||
model,
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
) = litellm.FireworksAIConfig()._get_openai_compatible_provider_info(
|
||||
model, api_base, api_key
|
||||
)
|
||||
elif custom_llm_provider == "azure_ai":
|
||||
(
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
custom_llm_provider,
|
||||
) = litellm.AzureAIStudioConfig()._get_openai_compatible_provider_info(
|
||||
model, api_base, api_key, custom_llm_provider
|
||||
)
|
||||
elif custom_llm_provider == "github":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("GITHUB_API_BASE")
|
||||
or "https://models.inference.ai.azure.com" # This is github's default base url
|
||||
)
|
||||
dynamic_api_key = api_key or get_secret_str("GITHUB_API_KEY")
|
||||
elif custom_llm_provider == "litellm_proxy":
|
||||
api_base = api_base or get_secret_str("LITELLM_PROXY_API_BASE")
|
||||
dynamic_api_key = api_key or get_secret_str("LITELLM_PROXY_API_KEY")
|
||||
|
||||
elif custom_llm_provider == "mistral":
|
||||
(
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
) = litellm.MistralConfig()._get_openai_compatible_provider_info(
|
||||
api_base, api_key
|
||||
)
|
||||
elif custom_llm_provider == "jina_ai":
|
||||
(
|
||||
custom_llm_provider,
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
) = litellm.JinaAIEmbeddingConfig()._get_openai_compatible_provider_info(
|
||||
api_base, api_key
|
||||
)
|
||||
elif custom_llm_provider == "voyage":
|
||||
# voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("VOYAGE_API_BASE")
|
||||
or "https://api.voyageai.com/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("VOYAGE_API_KEY")
|
||||
elif custom_llm_provider == "together_ai":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("TOGETHER_AI_API_BASE")
|
||||
or "https://api.together.xyz/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or (
|
||||
get_secret_str("TOGETHER_API_KEY")
|
||||
or get_secret_str("TOGETHER_AI_API_KEY")
|
||||
or get_secret_str("TOGETHERAI_API_KEY")
|
||||
or get_secret_str("TOGETHER_AI_TOKEN")
|
||||
)
|
||||
elif custom_llm_provider == "friendliai":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("FRIENDLI_API_BASE")
|
||||
or "https://inference.friendli.ai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = (
|
||||
api_key
|
||||
or get_secret_str("FRIENDLIAI_API_KEY")
|
||||
or get_secret_str("FRIENDLI_TOKEN")
|
||||
)
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception("api base needs to be a string. api_base={}".format(api_base))
|
||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||
raise Exception(
|
||||
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||
dynamic_api_key
|
||||
)
|
||||
)
|
||||
if dynamic_api_key is None and api_key is not None:
|
||||
dynamic_api_key = api_key
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
|
|
|
@ -17,6 +17,7 @@ from typing_extensions import overload, override
|
|||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.utils import ProviderField
|
||||
from litellm.utils import (
|
||||
Choices,
|
||||
|
@ -221,6 +222,18 @@ class DeepInfraConfig:
|
|||
optional_params[param] = value
|
||||
return optional_params
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
# deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("DEEPINFRA_API_BASE")
|
||||
or "https://api.deepinfra.com/v1/openai"
|
||||
)
|
||||
dynamic_api_key = api_key or get_secret_str("DEEPINFRA_API_KEY")
|
||||
return api_base, dynamic_api_key
|
||||
|
||||
|
||||
class OpenAIConfig:
|
||||
"""
|
||||
|
|
|
@ -274,9 +274,6 @@ class AnthropicConfig:
|
|||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
if "anthropic-beta" not in headers:
|
||||
# default to v1 of "anthropic-beta"
|
||||
headers["anthropic-beta"] = "tools-2024-05-16"
|
||||
anthropic_tools = []
|
||||
for tool in optional_params["tools"]:
|
||||
if "input_schema" in tool: # assume in anthropic format
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
from typing import List
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.OpenAI.openai import OpenAIConfig
|
||||
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ProviderField
|
||||
|
||||
|
@ -30,3 +33,36 @@ class AzureAIStudioConfig(OpenAIConfig):
|
|||
if texts:
|
||||
message["content"] = texts
|
||||
return messages
|
||||
|
||||
def _is_azure_openai_model(self, model: str) -> bool:
|
||||
try:
|
||||
if "/" in model:
|
||||
model = model.split("/", 1)[1]
|
||||
if (
|
||||
model in litellm.open_ai_chat_completion_models
|
||||
or model in litellm.open_ai_text_completion_models
|
||||
or model in litellm.open_ai_embedding_models
|
||||
):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
custom_llm_provider: str,
|
||||
) -> Tuple[Optional[str], Optional[str], str]:
|
||||
api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
|
||||
dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")
|
||||
|
||||
if self._is_azure_openai_model(model=model):
|
||||
verbose_logger.debug(
|
||||
"Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
|
||||
model
|
||||
)
|
||||
)
|
||||
custom_llm_provider = "azure"
|
||||
return api_base, dynamic_api_key, custom_llm_provider
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
import types
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Optional, Tuple, Union
|
||||
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
from ..embed.fireworks_ai_transformation import FireworksAIEmbeddingConfig
|
||||
|
||||
|
||||
class FireworksAIConfig:
|
||||
|
@ -107,3 +111,24 @@ class FireworksAIConfig:
|
|||
if value is not None:
|
||||
optional_params[param] = value
|
||||
return optional_params
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, model: str, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[str, Optional[str], Optional[str]]:
|
||||
if FireworksAIEmbeddingConfig().is_fireworks_embedding_model(model=model):
|
||||
# fireworks embeddings models do not require accounts/fireworks prefix https://docs.fireworks.ai/api-reference/creates-an-embedding-vector-representing-the-input-text
|
||||
pass
|
||||
elif not model.startswith("accounts/"):
|
||||
model = f"accounts/fireworks/models/{model}"
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("FIREWORKS_API_BASE")
|
||||
or "https://api.fireworks.ai/inference/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or (
|
||||
get_secret_str("FIREWORKS_API_KEY")
|
||||
or get_secret_str("FIREWORKS_AI_API_KEY")
|
||||
or get_secret_str("FIREWORKSAI_API_KEY")
|
||||
or get_secret_str("FIREWORKS_AI_TOKEN")
|
||||
)
|
||||
return model, api_base, dynamic_api_key
|
||||
|
|
|
@ -3,11 +3,12 @@ Translate from OpenAI's `/v1/chat/completions` to Groq's `/v1/chat/completions`
|
|||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
|
||||
|
||||
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
@ -86,3 +87,15 @@ class GroqChatConfig(OpenAIGPTConfig):
|
|||
messages[idx] = new_message
|
||||
|
||||
return messages
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
# groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("GROQ_API_BASE")
|
||||
or "https://api.groq.com/openai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("GROQ_API_KEY")
|
||||
return api_base, dynamic_api_key
|
||||
|
|
|
@ -3,11 +3,12 @@ Translate from OpenAI's `/v1/chat/completions` to VLLM's `/v1/chat/completions`
|
|||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
|
||||
|
||||
from ....utils import _remove_additional_properties, _remove_strict_from_schema
|
||||
|
@ -33,3 +34,12 @@ class HostedVLLMChatConfig(OpenAIGPTConfig):
|
|||
return super().map_openai_params(
|
||||
non_default_params, optional_params, model, drop_params
|
||||
)
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE") # type: ignore
|
||||
dynamic_api_key = (
|
||||
api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
|
||||
) # vllm does not require an api key
|
||||
return api_base, dynamic_api_key
|
||||
|
|
79
litellm/llms/jina_ai/embedding/transformation.py
Normal file
79
litellm/llms/jina_ai/embedding/transformation.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Jina AI's `/v1/embeddings` format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Docs - https://jina.ai/embeddings/
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from litellm import LlmProviders
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
|
||||
|
||||
class JinaAIEmbeddingConfig:
|
||||
"""
|
||||
Reference: https://jina.ai/embeddings/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self) -> List[str]:
|
||||
return ["dimensions"]
|
||||
|
||||
def map_openai_params(
|
||||
self, non_default_params: dict, optional_params: dict
|
||||
) -> dict:
|
||||
if "dimensions" in non_default_params:
|
||||
optional_params["dimensions"] = non_default_params["dimensions"]
|
||||
return optional_params
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
) -> Tuple[str, Optional[str], Optional[str]]:
|
||||
"""
|
||||
Returns:
|
||||
Tuple[str, Optional[str], Optional[str]]:
|
||||
- custom_llm_provider: str
|
||||
- api_base: str
|
||||
- dynamic_api_key: str
|
||||
"""
|
||||
api_base = (
|
||||
api_base or get_secret_str("JINA_AI_API_BASE") or "https://api.jina.ai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or (
|
||||
get_secret_str("JINA_AI_API_KEY")
|
||||
or get_secret_str("JINA_AI_API_KEY")
|
||||
or get_secret_str("JINA_AI_API_KEY")
|
||||
or get_secret_str("JINA_AI_TOKEN")
|
||||
)
|
||||
return LlmProviders.OPENAI_LIKE.value, api_base, dynamic_api_key
|
|
@ -7,7 +7,9 @@ Docs - https://docs.mistral.ai/api/
|
|||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class MistralConfig:
|
||||
|
@ -124,3 +126,25 @@ class MistralConfig:
|
|||
if param == "response_format":
|
||||
optional_params["response_format"] = value
|
||||
return optional_params
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
# mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("MISTRAL_AZURE_API_BASE") # for Azure AI Mistral
|
||||
or "https://api.mistral.ai/v1"
|
||||
) # type: ignore
|
||||
|
||||
# if api_base does not end with /v1 we add it
|
||||
if api_base is not None and not api_base.endswith(
|
||||
"/v1"
|
||||
): # Mistral always needs a /v1 at the end
|
||||
api_base = api_base + "/v1"
|
||||
dynamic_api_key = (
|
||||
api_key
|
||||
or get_secret_str("MISTRAL_AZURE_API_KEY") # for Azure AI Mistral
|
||||
or get_secret_str("MISTRAL_API_KEY")
|
||||
)
|
||||
return api_base, dynamic_api_key
|
||||
|
|
12
litellm/llms/openai_like/common_utils.py
Normal file
12
litellm/llms/openai_like/common_utils.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
import httpx
|
||||
|
||||
|
||||
class OpenAILikeError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://www.litellm.ai")
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
190
litellm/llms/openai_like/embedding/handler.py
Normal file
190
litellm/llms/openai_like/embedding/handler.py
Normal file
|
@ -0,0 +1,190 @@
|
|||
# What is this?
|
||||
## Handler file for OpenAI-like endpoints.
|
||||
## Allows jina ai embedding calls - which don't allow 'encoding_format' in payload.
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.utils import EmbeddingResponse
|
||||
|
||||
from ..common_utils import OpenAILikeError
|
||||
|
||||
|
||||
class OpenAILikeEmbeddingHandler:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _validate_environment(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
endpoint_type: Literal["chat_completions", "embeddings"],
|
||||
headers: Optional[dict],
|
||||
) -> Tuple[str, dict]:
|
||||
if api_key is None and headers is None:
|
||||
raise OpenAILikeError(
|
||||
status_code=400,
|
||||
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
raise OpenAILikeError(
|
||||
status_code=400,
|
||||
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
if headers is None:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if api_key is not None:
|
||||
headers.update({"Authorization": "Bearer {}".format(api_key)})
|
||||
|
||||
if endpoint_type == "chat_completions":
|
||||
api_base = "{}/chat/completions".format(api_base)
|
||||
elif endpoint_type == "embeddings":
|
||||
api_base = "{}/embeddings".format(api_base)
|
||||
return api_base, headers
|
||||
|
||||
async def aembedding(
|
||||
self,
|
||||
input: list,
|
||||
data: dict,
|
||||
model_response: EmbeddingResponse,
|
||||
timeout: float,
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
logging_obj,
|
||||
headers: dict,
|
||||
client=None,
|
||||
) -> EmbeddingResponse:
|
||||
response = None
|
||||
try:
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
self.async_client = AsyncHTTPHandler(timeout=timeout) # type: ignore
|
||||
else:
|
||||
self.async_client = client
|
||||
|
||||
try:
|
||||
response = await self.async_client.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
) # type: ignore
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise OpenAILikeError(
|
||||
status_code=e.response.status_code,
|
||||
message=response.text if response else str(e),
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
raise OpenAILikeError(
|
||||
status_code=408, message="Timeout error occurred."
|
||||
)
|
||||
except Exception as e:
|
||||
raise OpenAILikeError(status_code=500, message=str(e))
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response_json,
|
||||
)
|
||||
return EmbeddingResponse(**response_json)
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
def embedding(
|
||||
self,
|
||||
model: str,
|
||||
input: list,
|
||||
timeout: float,
|
||||
logging_obj,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
optional_params: dict,
|
||||
model_response: Optional[litellm.utils.EmbeddingResponse] = None,
|
||||
client=None,
|
||||
aembedding=None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> EmbeddingResponse:
|
||||
api_base, headers = self._validate_environment(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
endpoint_type="embeddings",
|
||||
headers=headers,
|
||||
)
|
||||
model = model
|
||||
data = {"model": model, "input": input, **optional_params}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data, "api_base": api_base},
|
||||
)
|
||||
|
||||
if aembedding is True:
|
||||
return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers) # type: ignore
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
self.client = HTTPHandler(timeout=timeout) # type: ignore
|
||||
else:
|
||||
self.client = client
|
||||
|
||||
## EMBEDDING CALL
|
||||
try:
|
||||
response = self.client.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
) # type: ignore
|
||||
|
||||
response.raise_for_status() # type: ignore
|
||||
|
||||
response_json = response.json() # type: ignore
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise OpenAILikeError(
|
||||
status_code=e.response.status_code,
|
||||
message=e.response.text,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
raise OpenAILikeError(status_code=408, message="Timeout error occurred.")
|
||||
except Exception as e:
|
||||
raise OpenAILikeError(status_code=500, message=str(e))
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response_json,
|
||||
)
|
||||
|
||||
return litellm.EmbeddingResponse(**response_json)
|
28
litellm/llms/perplexity/chat/transformation.py
Normal file
28
litellm/llms/perplexity/chat/transformation.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
"""
|
||||
Translate from OpenAI's `/v1/chat/completions` to Perplexity's `/v1/chat/completions`
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
|
||||
|
||||
from ....utils import _remove_additional_properties, _remove_strict_from_schema
|
||||
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class PerplexityChatConfig(OpenAIGPTConfig):
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
api_base = api_base or get_secret_str("PERPLEXITY_API_BASE") or "https://api.perplexity.ai" # type: ignore
|
||||
dynamic_api_key = (
|
||||
api_key
|
||||
or get_secret_str("PERPLEXITYAI_API_KEY")
|
||||
or get_secret_str("PERPLEXITY_API_KEY")
|
||||
)
|
||||
return api_base, dynamic_api_key
|
|
@ -121,6 +121,7 @@ from .llms.huggingface_restapi import Huggingface
|
|||
from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
|
||||
from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion
|
||||
from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion
|
||||
from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.prompt_templates.common_utils import get_completion_messages
|
||||
from .llms.prompt_templates.factory import (
|
||||
|
@ -220,6 +221,7 @@ vertex_partner_models_chat_completion = VertexAIPartnerModels()
|
|||
vertex_text_to_speech = VertexTextToSpeechAPI()
|
||||
watsonxai = IBMWatsonXAI()
|
||||
sagemaker_llm = SagemakerLLM()
|
||||
openai_like_embedding = OpenAILikeEmbeddingHandler()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@ -3129,6 +3131,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
|
|||
or custom_llm_provider == "bedrock"
|
||||
or custom_llm_provider == "azure_ai"
|
||||
or custom_llm_provider == "together_ai"
|
||||
or custom_llm_provider == "openai_like"
|
||||
): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
# Await normally
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -3477,6 +3480,32 @@ def embedding( # noqa: PLR0915
|
|||
client=client,
|
||||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "openai_like":
|
||||
api_base = (
|
||||
api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
|
||||
)
|
||||
|
||||
# set API KEY
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.openai_like_key
|
||||
or get_secret_str("OPENAI_LIKE_API_KEY")
|
||||
)
|
||||
|
||||
## EMBEDDING CALL
|
||||
response = openai_like_embedding.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
timeout=timeout,
|
||||
model_response=EmbeddingResponse(),
|
||||
optional_params=optional_params,
|
||||
client=client,
|
||||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat":
|
||||
cohere_key = (
|
||||
api_key
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
model_list:
|
||||
- model_name: "gpt-4o-audio-preview"
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-4o-audio-preview
|
||||
model: gpt-3.5-turbo
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["prometheus"]
|
|
@ -153,6 +153,12 @@ def is_port_in_use(port):
|
|||
type=bool,
|
||||
help="Helps us know if people are using this feature. Turn this off by doing `--telemetry False`",
|
||||
)
|
||||
@click.option(
|
||||
"--log_config",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the logging configuration file",
|
||||
)
|
||||
@click.option(
|
||||
"--version",
|
||||
"-v",
|
||||
|
@ -249,6 +255,7 @@ def run_server( # noqa: PLR0915
|
|||
run_hypercorn,
|
||||
ssl_keyfile_path,
|
||||
ssl_certfile_path,
|
||||
log_config,
|
||||
):
|
||||
args = locals()
|
||||
if local:
|
||||
|
@ -690,25 +697,26 @@ def run_server( # noqa: PLR0915
|
|||
# DO NOT DELETE - enables global variables to work across files
|
||||
from litellm.proxy.proxy_server import app # noqa
|
||||
|
||||
uvicorn_args = {
|
||||
"app": app,
|
||||
"host": host,
|
||||
"port": port,
|
||||
}
|
||||
if log_config is not None:
|
||||
print(f"Using log_config: {log_config}") # noqa
|
||||
uvicorn_args["log_config"] = log_config
|
||||
elif litellm.json_logs:
|
||||
print("Using json logs. Setting log_config to None.") # noqa
|
||||
uvicorn_args["log_config"] = None
|
||||
|
||||
if run_gunicorn is False and run_hypercorn is False:
|
||||
if ssl_certfile_path is not None and ssl_keyfile_path is not None:
|
||||
print( # noqa
|
||||
f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n" # noqa
|
||||
)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl_keyfile=ssl_keyfile_path,
|
||||
ssl_certfile=ssl_certfile_path,
|
||||
) # run uvicorn
|
||||
else:
|
||||
if litellm.json_logs:
|
||||
uvicorn.run(
|
||||
app, host=host, port=port, log_config=None
|
||||
) # run uvicorn w/ json
|
||||
else:
|
||||
uvicorn.run(app, host=host, port=port) # run uvicorn
|
||||
uvicorn_args["ssl_keyfile"] = ssl_keyfile_path
|
||||
uvicorn_args["ssl_certfile"] = ssl_certfile_path
|
||||
uvicorn.run(**uvicorn_args)
|
||||
elif run_gunicorn is True:
|
||||
# Gunicorn Application Class
|
||||
class StandaloneApplication(gunicorn.app.base.BaseApplication):
|
||||
|
|
|
@ -195,6 +195,8 @@ async def test_get_router_response():
|
|||
|
||||
print(f"\n\nResponse: {response}\n\n")
|
||||
|
||||
except litellm.ServiceUnavailableError:
|
||||
pass
|
||||
except litellm.UnprocessableEntityError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
|
|
@ -6,7 +6,8 @@ from dotenv import load_dotenv
|
|||
|
||||
load_dotenv()
|
||||
import io
|
||||
import os
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
|
@ -124,3 +125,38 @@ def test_get_llm_provider_azure_o1():
|
|||
)
|
||||
assert custom_llm_provider == "azure"
|
||||
assert model == "o1-mini"
|
||||
|
||||
|
||||
def test_default_api_base():
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import (
|
||||
_get_openai_compatible_provider_info,
|
||||
)
|
||||
|
||||
# Patch environment variable to remove API base if it's set
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
for provider in litellm.openai_compatible_providers:
|
||||
# Get the API base for the given provider
|
||||
_, _, _, api_base = _get_openai_compatible_provider_info(
|
||||
model=f"{provider}/*", api_base=None, api_key=None, dynamic_api_key=None
|
||||
)
|
||||
if api_base is None:
|
||||
continue
|
||||
|
||||
for other_provider in litellm.provider_list:
|
||||
if other_provider != provider and provider != "{}_chat".format(
|
||||
other_provider.value
|
||||
):
|
||||
if provider == "codestral" and other_provider == "mistral":
|
||||
continue
|
||||
elif provider == "github" and other_provider == "azure":
|
||||
continue
|
||||
assert other_provider.value not in api_base.replace("/openai", "")
|
||||
|
||||
|
||||
def test_get_llm_provider_jina_ai():
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
||||
model="jina_ai/jina-embeddings-v3",
|
||||
)
|
||||
assert custom_llm_provider == "openai_like"
|
||||
assert api_base == "https://api.jina.ai/v1"
|
||||
assert model == "jina-embeddings-v3"
|
||||
|
|
36
tests/test_logging.conf
Normal file
36
tests/test_logging.conf
Normal file
|
@ -0,0 +1,36 @@
|
|||
[loggers]
|
||||
keys=root,my_module
|
||||
|
||||
[handlers]
|
||||
keys=consoleHandler,fileHandler
|
||||
|
||||
[formatters]
|
||||
keys=simpleFormatter,detailedFormatter
|
||||
|
||||
[logger_root]
|
||||
level=WARNING
|
||||
handlers=consoleHandler
|
||||
|
||||
[logger_my_module]
|
||||
level=DEBUG
|
||||
handlers=consoleHandler,fileHandler
|
||||
qualname=my_module
|
||||
propagate=0
|
||||
|
||||
[handler_consoleHandler]
|
||||
class=StreamHandler
|
||||
level=DEBUG
|
||||
formatter=simpleFormatter
|
||||
args=(sys.stdout,)
|
||||
|
||||
[handler_fileHandler]
|
||||
class=FileHandler
|
||||
level=INFO
|
||||
formatter=detailedFormatter
|
||||
args=('app.log', 'a')
|
||||
|
||||
[formatter_simpleFormatter]
|
||||
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
|
||||
|
||||
[formatter_detailedFormatter]
|
||||
format=%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s
|
Loading…
Add table
Add a link
Reference in a new issue