* feat(proxy_cli.py): add new 'log_config' cli param. Allows passing a logging.conf to uvicorn on startup.
* docs(cli.md): add logging conf to uvicorn cli docs.
* fix(get_llm_provider_logic.py): fix default api base for litellm_proxy. Fixes https://github.com/BerriAI/litellm/issues/6332
* feat(openai_like/embedding): add support for Jina AI embeddings. Closes https://github.com/BerriAI/litellm/issues/6337
* docs(deploy.md): update entrypoint.sh filepath post-refactor. Fixes outdated docs.
* feat(prometheus.py): emit time_to_first_token metric on prometheus. Closes https://github.com/BerriAI/litellm/issues/6334
* fix(prometheus.py): only emit time_to_first_token metric if stream is True. Enables more accurate ttft usage.
* test: handle vertex api instability
* fix(get_llm_provider_logic.py): fix import
* fix(openai.py): fix deepinfra default api base
* fix(anthropic/transformation.py): remove anthropic beta header (#6361)
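For context on the first bullet, a minimal sketch of what handing a logging.conf through to uvicorn can look like. Only the log_config keyword itself is uvicorn's documented API; the app import path, port, and file name below are illustrative assumptions, not litellm's actual wiring:

import uvicorn

# Sketch only: "logging.conf" is an INI-style file that uvicorn passes to
# logging.config.fileConfig. App path and port are assumed for illustration.
uvicorn.run(
    "litellm.proxy.proxy_server:app",  # assumed import path
    host="0.0.0.0",
    port=4000,
    log_config="logging.conf",
)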
from typing import List, Optional, Tuple

import litellm
from litellm._logging import verbose_logger
from litellm.llms.OpenAI.openai import OpenAIConfig
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField

class AzureAIStudioConfig(OpenAIConfig):
    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields, each with a description."""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Azure AI Studio API Key.",
                field_value="zEJ...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Azure AI Studio API Base.",
                field_value="https://Mistral-serverless.",
            ),
        ]

    def _transform_messages(self, messages: List[AllMessageValues]) -> List:
        # Flatten any multi-part content lists into a single string, since
        # the endpoint expects plain-string message content.
        for message in messages:
            texts = convert_content_list_to_str(message=message)
            if texts:
                message["content"] = texts
        return messages

    def _is_azure_openai_model(self, model: str) -> bool:
        try:
            # Strip a leading provider prefix, e.g. "azure_ai/gpt-4o" -> "gpt-4o".
            if "/" in model:
                model = model.split("/", 1)[1]
            if (
                model in litellm.open_ai_chat_completion_models
                or model in litellm.open_ai_text_completion_models
                or model in litellm.open_ai_embedding_models
            ):
                return True
        except Exception:
            return False
        return False

    def _get_openai_compatible_provider_info(
        self,
        model: str,
        api_base: Optional[str],
        api_key: Optional[str],
        custom_llm_provider: str,
    ) -> Tuple[Optional[str], Optional[str], str]:
        # Fall back to secret/env values when no explicit credentials are passed.
        api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
        dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")

        if self._is_azure_openai_model(model=model):
            verbose_logger.debug(
                "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
                    model
                )
            )
            custom_llm_provider = "azure"
        return api_base, dynamic_api_key, custom_llm_provider
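
A quick usage sketch of the provider resolution above. The endpoint URL and key are made-up placeholders; everything else mirrors the class as defined in this file, assuming get_secret_str falls back to environment variables:

import os

os.environ["AZURE_AI_API_BASE"] = "https://example-endpoint.inference.ai.azure.com"  # placeholder
os.environ["AZURE_AI_API_KEY"] = "zEJ-placeholder"  # placeholder

config = AzureAIStudioConfig()

# A serverless (non-OpenAI) model keeps the incoming provider label.
print(config._get_openai_compatible_provider_info(
    model="mistral-large",
    api_base=None,
    api_key=None,
    custom_llm_provider="azure_ai",
))
# -> ("https://example-endpoint.inference.ai.azure.com", "zEJ-placeholder", "azure_ai")

# A model found in litellm's OpenAI model lists is rerouted to "azure".
print(config._get_openai_compatible_provider_info(
    model="azure_ai/gpt-3.5-turbo",
    api_base=None,
    api_key=None,
    custom_llm_provider="azure_ai",
))
# -> (..., ..., "azure")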