Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
Add pyright to ci/cd + Fix remaining type-checking errors (#6082)
* fix: fix type-checking errors
* fix: fix additional type-checking errors
* fix: additional type-checking error fixes
* fix: fix additional type-checking errors
* fix: additional type-check fixes
* fix: fix all type-checking errors + add pyright to ci/cd
* fix: fix incorrect import
* ci(config.yml): use mypy on ci/cd
* fix: fix type-checking errors in utils.py
* fix: fix all type-checking errors on main.py
* fix: fix mypy linting errors
* fix(anthropic/cost_calculator.py): fix linting errors
* fix: fix mypy linting errors
* fix: fix linting errors
This commit is contained in: parent f7ce1173f3, commit fac3b2ee42
65 changed files with 619 additions and 522 deletions
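Most of the changes in this file follow two recurring patterns needed to satisfy pyright: restructuring if/elif chains so a variable can never be left possibly unbound, and declaring a response as Optional[...] and raising explicitly when it stays None. A minimal, hypothetical illustration of both patterns (not code from this repo):

from typing import Optional

def pick_retryer(strategy: str) -> str:
    # Pattern 1: an if/elif with no else can leave the variable unbound, which pyright flags;
    # the fix used in this diff is to make the fallback branch an `else`.
    if strategy == "exponential_backoff_retry":
        retryer = "exponential"
    else:
        retryer = "constant"
    return retryer

def get_response() -> str:
    # Pattern 2: declare the value as Optional, then raise if nothing set it,
    # so pyright knows the returned value is no longer Optional.
    response: Optional[str] = None
    for candidate in ("", "ok"):
        if candidate:
            response = candidate
    if response is None:
        raise ValueError("Unable to get a response.")
    return response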
litellm/main.py (222 changed lines)
@@ -19,7 +19,8 @@ import threading
 import time
 import traceback
 import uuid
-from concurrent.futures import ThreadPoolExecutor
+from concurrent import futures
+from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
 from copy import deepcopy
 from functools import partial
 from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Type, Union
@@ -647,7 +648,7 @@ def mock_completion(


 @client
-def completion(
+def completion(  # type: ignore
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
     messages: List = [],
@@ -2940,16 +2941,16 @@ def completion_with_retries(*args, **kwargs):
     num_retries = kwargs.pop("num_retries", 3)
     retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry")  # type: ignore
     original_function = kwargs.pop("original_function", completion)
-    if retry_strategy == "constant_retry":
-        retryer = tenacity.Retrying(
-            stop=tenacity.stop_after_attempt(num_retries), reraise=True
-        )
-    elif retry_strategy == "exponential_backoff_retry":
+    if retry_strategy == "exponential_backoff_retry":
         retryer = tenacity.Retrying(
             wait=tenacity.wait_exponential(multiplier=1, max=10),
             stop=tenacity.stop_after_attempt(num_retries),
             reraise=True,
         )
+    else:
+        retryer = tenacity.Retrying(
+            stop=tenacity.stop_after_attempt(num_retries), reraise=True
+        )
     return retryer(original_function, *args, **kwargs)
@@ -2968,16 +2969,16 @@ async def acompletion_with_retries(*args, **kwargs):
     num_retries = kwargs.pop("num_retries", 3)
     retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
     original_function = kwargs.pop("original_function", completion)
-    if retry_strategy == "constant_retry":
-        retryer = tenacity.Retrying(
-            stop=tenacity.stop_after_attempt(num_retries), reraise=True
-        )
-    elif retry_strategy == "exponential_backoff_retry":
+    if retry_strategy == "exponential_backoff_retry":
         retryer = tenacity.Retrying(
             wait=tenacity.wait_exponential(multiplier=1, max=10),
             stop=tenacity.stop_after_attempt(num_retries),
             reraise=True,
         )
+    else:
+        retryer = tenacity.Retrying(
+            stop=tenacity.stop_after_attempt(num_retries), reraise=True
+        )
     return await retryer(original_function, *args, **kwargs)
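For context, a sketch of how the retry wrapper above is typically called (model and messages are placeholders, and valid provider credentials would be required); after this change an unrecognized retry_strategy falls back to the constant-retry branch instead of leaving retryer possibly unbound:

from litellm.main import completion_with_retries

# hypothetical call; needs a configured provider key to actually succeed
response = completion_with_retries(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    num_retries=3,
    retry_strategy="exponential_backoff_retry",
)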
@@ -3045,7 +3046,7 @@ def batch_completion(
            temperature=temperature,
            top_p=top_p,
            n=n,
-            stream=stream,
+            stream=stream or False,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
@@ -3124,7 +3125,7 @@ def batch_completion_models(*args, **kwargs):
        models = kwargs["models"]
        kwargs.pop("models")
        futures = {}
-        with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
+        with ThreadPoolExecutor(max_workers=len(models)) as executor:
            for model in models:
                futures[model] = executor.submit(
                    completion, *args, model=model, **kwargs
@@ -3141,9 +3142,7 @@ def batch_completion_models(*args, **kwargs):
        kwargs.pop("model_list")
        nested_kwargs = kwargs.pop("kwargs", {})
        futures = {}
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=len(deployments)
-        ) as executor:
+        with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
            for deployment in deployments:
                for key in kwargs.keys():
                    if (
@@ -3156,9 +3155,7 @@ def batch_completion_models(*args, **kwargs):
        while futures:
            # wait for the first returned future
            print_verbose("\n\n waiting for next result\n\n")
-            done, _ = concurrent.futures.wait(
-                futures.values(), return_when=concurrent.futures.FIRST_COMPLETED
-            )
+            done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
            print_verbose(f"done list\n{done}")
            for future in done:
                try:
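The change above also switches from module-qualified concurrent.futures calls to the names imported at the top of the file. A standalone sketch of the same first-completed pattern, with a stub task standing in for completion():

import random
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

def fake_provider(name: str) -> str:
    time.sleep(random.random())  # simulate variable provider latency
    return f"response from {name}"

with ThreadPoolExecutor(max_workers=3) as executor:
    futures = {name: executor.submit(fake_provider, name) for name in ("a", "b", "c")}
    done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
    fastest = next(iter(done)).result()  # first provider to finish wins
    print(fastest)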
@@ -3214,6 +3211,8 @@ def batch_completion_models_all_responses(*args, **kwargs):
    if "models" in kwargs:
        models = kwargs["models"]
        kwargs.pop("models")
+    else:
+        raise Exception("'models' param not in kwargs")

    responses = []

@@ -3256,6 +3255,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
            model=model, api_base=kwargs.get("api_base", None)
        )

+        response: Optional[EmbeddingResponse] = None
        if (
            custom_llm_provider == "openai"
            or custom_llm_provider == "azure"
@@ -3294,12 +3294,21 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
        elif isinstance(init_response, EmbeddingResponse):  ## CACHING SCENARIO
            response = init_response
        elif asyncio.iscoroutine(init_response):
-            response = await init_response
+            response = await init_response  # type: ignore
        else:
            # Call the synchronous function using run_in_executor
            response = await loop.run_in_executor(None, func_with_context)
-        if response is not None and hasattr(response, "_hidden_params"):
+        if (
+            response is not None
+            and isinstance(response, EmbeddingResponse)
+            and hasattr(response, "_hidden_params")
+        ):
            response._hidden_params["custom_llm_provider"] = custom_llm_provider
+
+        if response is None:
+            raise ValueError(
+                "Unable to get Embedding Response. Please pass a valid llm_provider."
+            )
        return response
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
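The widened check above exists because, as far as the type checker is concerned, response may be something other than an EmbeddingResponse at that point; isinstance narrows the type before _hidden_params is touched. A small, generic sketch of the same narrowing idea (the class and attribute names here are illustrative only, not litellm's types):

from typing import Optional, Union

class EmbeddingLike:
    def __init__(self) -> None:
        self._hidden_params: dict = {}

def tag_provider(response: Optional[Union[EmbeddingLike, dict]], provider: str) -> None:
    # isinstance() narrows the union so the attribute access below type-checks
    if response is not None and isinstance(response, EmbeddingLike):
        response._hidden_params["custom_llm_provider"] = provider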
@@ -3329,7 +3338,6 @@ def embedding(
    user: Optional[str] = None,
    custom_llm_provider=None,
    litellm_call_id=None,
-    litellm_logging_obj=None,
    logger_fn=None,
    **kwargs,
) -> EmbeddingResponse:
@@ -3362,6 +3370,7 @@ def embedding(
    client = kwargs.pop("client", None)
    rpm = kwargs.pop("rpm", None)
    tpm = kwargs.pop("tpm", None)
+    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
    cooldown_time = kwargs.get("cooldown_time", None)
    max_parallel_requests = kwargs.pop("max_parallel_requests", None)
    model_info = kwargs.get("model_info", None)
@@ -3491,7 +3500,7 @@ def embedding(
        }
    )
    try:
-        response = None
+        response: Optional[EmbeddingResponse] = None
        logging: Logging = litellm_logging_obj  # type: ignore
        logging.update_environment_variables(
            model=model,
@@ -3691,7 +3700,7 @@ def embedding(
                raise ValueError(
                    "api_base is required for triton. Please pass `api_base`"
                )
-            response = triton_chat_completions.embedding(
+            response = triton_chat_completions.embedding(  # type: ignore
                model=model,
                input=input,
                api_base=api_base,
@@ -3783,6 +3792,7 @@ def embedding(
                timeout=timeout,
                aembedding=aembedding,
                print_verbose=print_verbose,
+                api_key=api_key,
            )
        elif custom_llm_provider == "oobabooga":
            response = oobabooga.embedding(
@@ -3793,14 +3803,16 @@ def embedding(
                logging_obj=logging,
                optional_params=optional_params,
                model_response=EmbeddingResponse(),
+                api_key=api_key,
            )
        elif custom_llm_provider == "ollama":
            api_base = (
                litellm.api_base
                or api_base
-                or get_secret("OLLAMA_API_BASE")
+                or get_secret_str("OLLAMA_API_BASE")
                or "http://localhost:11434"
            )  # type: ignore
+
            if isinstance(input, str):
                input = [input]
            if not all(isinstance(item, str) for item in input):
@@ -3881,13 +3893,13 @@ def embedding(
            api_key = (
                api_key
                or litellm.api_key
-                or get_secret("XINFERENCE_API_KEY")
+                or get_secret_str("XINFERENCE_API_KEY")
                or "stub-xinference-key"
            )  # xinference does not need an api key, pass a stub key if user did not set one
            api_base = (
                api_base
                or litellm.api_base
-                or get_secret("XINFERENCE_API_BASE")
+                or get_secret_str("XINFERENCE_API_BASE")
                or "http://127.0.0.1:9997/v1"
            )
            response = openai_chat_completions.embedding(
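Many of the one-line changes in this file swap get_secret for get_secret_str. A plausible reading (the helper's implementation is not shown in this hunk, so treat this as an assumption) is that get_secret can return several types while get_secret_str narrows the result to Optional[str], which keeps pyright happy in or-chains like the ones above. A minimal sketch of such a wrapper:

import os
from typing import Any, Optional

def get_secret(name: str, default: Any = None) -> Any:
    # stand-in for a secret-manager lookup that may return str, bool, dict, ...
    return os.environ.get(name, default)

def get_secret_str(name: str, default: Any = None) -> Optional[str]:
    # hypothetical str-narrowing wrapper; the real litellm helper may differ
    value = get_secret(name, default)
    return value if isinstance(value, str) else None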
@@ -3911,19 +3923,20 @@ def embedding(
                optional_params=optional_params,
                model_response=EmbeddingResponse(),
                aembedding=aembedding,
+                api_key=api_key,
            )
        elif custom_llm_provider == "azure_ai":
            api_base = (
                api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
                or litellm.api_base
-                or get_secret("AZURE_AI_API_BASE")
+                or get_secret_str("AZURE_AI_API_BASE")
            )
            # set API KEY
            api_key = (
                api_key
                or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
                or litellm.openai_key
-                or get_secret("AZURE_AI_API_KEY")
+                or get_secret_str("AZURE_AI_API_KEY")
            )

            ## EMBEDDING CALL
@@ -3944,10 +3957,14 @@ def embedding(
            raise ValueError(f"No valid embedding model args passed in - {args}")
        if response is not None and hasattr(response, "_hidden_params"):
            response._hidden_params["custom_llm_provider"] = custom_llm_provider
+
+        if response is None:
+            args = locals()
+            raise ValueError(f"No valid embedding model args passed in - {args}")
        return response
    except Exception as e:
        ## LOGGING
-        logging.post_call(
+        litellm_logging_obj.post_call(
            input=input,
            api_key=api_key,
            original_response=str(e),
@@ -4018,7 +4035,11 @@ async def atext_completion(
        else:
            # Call the synchronous function using run_in_executor
            response = await loop.run_in_executor(None, func_with_context)
-        if kwargs.get("stream", False) is True:  # return an async generator
+        if (
+            kwargs.get("stream", False) is True
+            or isinstance(response, TextCompletionStreamWrapper)
+            or isinstance(response, CustomStreamWrapper)
+        ):  # return an async generator
            return TextCompletionStreamWrapper(
                completion_stream=_async_streaming(
                    response=response,
@@ -4153,9 +4174,10 @@ def text_completion(
    Your example of how to use this function goes here.
    """
    if "engine" in kwargs:
-        if model is None:
+        _engine = kwargs["engine"]
+        if model is None and isinstance(_engine, str):
            # only use engine when model not passed
-            model = kwargs["engine"]
+            model = _engine
        kwargs.pop("engine")

    text_completion_response = TextCompletionResponse()
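For reference, a minimal call that exercises the engine fallback shown above; the engine value is a placeholder and a valid provider key would be needed for the request itself to succeed. engine is only used when model is not passed:

from litellm.main import text_completion

# hypothetical deployment name; engine is copied into model when model is None
resp = text_completion(engine="my-openai-deployment", prompt="Say hi")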
@@ -4223,7 +4245,7 @@ def text_completion(
        def process_prompt(i, individual_prompt):
            decoded_prompt = tokenizer.decode(individual_prompt)
            all_params = {**kwargs, **optional_params}
-            response = text_completion(
+            response: TextCompletionResponse = text_completion(  # type: ignore
                model=model,
                prompt=decoded_prompt,
                num_retries=3,  # ensure this does not fail for the batch
@@ -4292,6 +4314,8 @@ def text_completion(
            model = "text-completion-openai/" + _model
            optional_params.pop("custom_llm_provider", None)

+    if model is None:
+        raise ValueError("model is not set. Set either via 'model' or 'engine' param.")
    kwargs["text_completion"] = True
    response = completion(
        model=model,
@@ -4302,7 +4326,11 @@ def text_completion(
    )
    if kwargs.get("acompletion", False) is True:
        return response
-    if stream is True or kwargs.get("stream", False) is True:
+    if (
+        stream is True
+        or kwargs.get("stream", False) is True
+        or isinstance(response, CustomStreamWrapper)
+    ):
        response = TextCompletionStreamWrapper(
            completion_stream=response,
            model=model,
@@ -4310,6 +4338,8 @@ def text_completion(
            custom_llm_provider=custom_llm_provider,
        )
        return response
+    elif isinstance(response, TextCompletionStreamWrapper):
+        return response
    transformed_logprobs = None
    # only supported for TGI models
    try:
@@ -4424,7 +4454,10 @@ def moderation(
):
    # only supports open ai for now
    api_key = (
-        api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
+        api_key
+        or litellm.api_key
+        or litellm.openai_key
+        or get_secret_str("OPENAI_API_KEY")
    )

    openai_client = kwargs.get("client", None)
@@ -4433,7 +4466,10 @@ def moderation(
        api_key=api_key,
    )

-    response = openai_client.moderations.create(input=input, model=model)
+    if model is not None:
+        response = openai_client.moderations.create(input=input, model=model)
+    else:
+        response = openai_client.moderations.create(input=input)
    return response

@@ -4441,20 +4477,30 @@ def moderation(
async def amoderation(
    input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs
):
+    from openai import AsyncOpenAI
+
    # only supports open ai for now
    api_key = (
-        api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
+        api_key
+        or litellm.api_key
+        or litellm.openai_key
+        or get_secret_str("OPENAI_API_KEY")
    )
    openai_client = kwargs.get("client", None)
-    if openai_client is None:
+    if openai_client is None or not isinstance(openai_client, AsyncOpenAI):

        # call helper to get OpenAI client
        # _get_openai_client maintains in-memory caching logic for OpenAI clients
-        openai_client = openai_chat_completions._get_openai_client(
+        _openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client(  # type: ignore
            is_async=True,
            api_key=api_key,
        )
-    response = await openai_client.moderations.create(input=input, model=model)
+    else:
+        _openai_client = openai_client
+    if model is not None:
+        response = await openai_client.moderations.create(input=input, model=model)
+    else:
+        response = await openai_client.moderations.create(input=input)
    return response

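Both moderation paths now avoid forwarding model=None to the SDK. A standalone sketch of the same conditional-kwarg pattern, assuming the openai>=1.x client (the model value is a placeholder and OPENAI_API_KEY must be set for a real call):

from typing import Optional
from openai import OpenAI

def moderate(text: str, model: Optional[str] = None):
    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    if model is not None:
        # only pass model through when the caller actually supplied one
        return client.moderations.create(input=text, model=model)
    return client.moderations.create(input=text)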
@@ -4497,7 +4543,7 @@ async def aimage_generation(*args, **kwargs) -> ImageResponse:
            init_response = ImageResponse(**init_response)
            response = init_response
        elif asyncio.iscoroutine(init_response):
-            response = await init_response
+            response = await init_response  # type: ignore
        else:
            # Call the synchronous function using run_in_executor
            response = await loop.run_in_executor(None, func_with_context)
@@ -4527,7 +4573,6 @@ def image_generation(
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
-    litellm_logging_obj=None,
    custom_llm_provider=None,
    **kwargs,
) -> ImageResponse:
@@ -4543,9 +4588,10 @@ def image_generation(
    proxy_server_request = kwargs.get("proxy_server_request", None)
    model_info = kwargs.get("model_info", None)
    metadata = kwargs.get("metadata", {})
+    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
    client = kwargs.get("client", None)

-    model_response = litellm.utils.ImageResponse()
+    model_response: ImageResponse = litellm.utils.ImageResponse()
    if model is not None or custom_llm_provider is not None:
        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
    else:
@@ -4651,25 +4697,27 @@ def image_generation(

    if custom_llm_provider == "azure":
        # azure configs
-        api_type = get_secret("AZURE_API_TYPE") or "azure"
+        api_type = get_secret_str("AZURE_API_TYPE") or "azure"

-        api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+        api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")

        api_version = (
-            api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
+            api_version
+            or litellm.api_version
+            or get_secret_str("AZURE_API_VERSION")
        )

        api_key = (
            api_key
            or litellm.api_key
            or litellm.azure_key
-            or get_secret("AZURE_OPENAI_API_KEY")
-            or get_secret("AZURE_API_KEY")
+            or get_secret_str("AZURE_OPENAI_API_KEY")
+            or get_secret_str("AZURE_API_KEY")
        )

-        azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
-            "AZURE_AD_TOKEN"
-        )
+        azure_ad_token = optional_params.pop(
+            "azure_ad_token", None
+        ) or get_secret_str("AZURE_AD_TOKEN")

        model_response = azure_chat_completions.image_generation(
            model=model,
@@ -4714,18 +4762,18 @@ def image_generation(
            optional_params.pop("vertex_project", None)
            or optional_params.pop("vertex_ai_project", None)
            or litellm.vertex_project
-            or get_secret("VERTEXAI_PROJECT")
+            or get_secret_str("VERTEXAI_PROJECT")
        )
        vertex_ai_location = (
            optional_params.pop("vertex_location", None)
            or optional_params.pop("vertex_ai_location", None)
            or litellm.vertex_location
-            or get_secret("VERTEXAI_LOCATION")
+            or get_secret_str("VERTEXAI_LOCATION")
        )
        vertex_credentials = (
            optional_params.pop("vertex_credentials", None)
            or optional_params.pop("vertex_ai_credentials", None)
-            or get_secret("VERTEXAI_CREDENTIALS")
+            or get_secret_str("VERTEXAI_CREDENTIALS")
        )
        model_response = vertex_image_generation.image_generation(
            model=model,
@@ -4786,7 +4834,7 @@ async def atranscription(*args, **kwargs) -> TranscriptionResponse:
        elif isinstance(init_response, TranscriptionResponse):  ## CACHING SCENARIO
            response = init_response
        elif asyncio.iscoroutine(init_response):
-            response = await init_response
+            response = await init_response  # type: ignore
        else:
            # Call the synchronous function using run_in_executor
            response = await loop.run_in_executor(None, func_with_context)
@@ -4820,7 +4868,6 @@ def transcription(
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    max_retries: Optional[int] = None,
-    litellm_logging_obj: Optional[LiteLLMLoggingObj] = None,
    custom_llm_provider=None,
    **kwargs,
) -> TranscriptionResponse:
@@ -4830,6 +4877,7 @@ def transcription(
    Allows router to load balance between them
    """
    atranscription = kwargs.get("atranscription", False)
+    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
    kwargs.get("litellm_call_id", None)
    kwargs.get("logger_fn", None)
    kwargs.get("proxy_server_request", None)
@@ -4869,22 +4917,17 @@ def transcription(
        custom_llm_provider=custom_llm_provider,
        drop_params=drop_params,
    )
-    # optional_params = {
-    #     "language": language,
-    #     "prompt": prompt,
-    #     "response_format": response_format,
-    #     "temperature": None,  # openai defaults this to 0
-    # }

+    response: Optional[TranscriptionResponse] = None
    if custom_llm_provider == "azure":
        # azure configs
-        api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+        api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")

        api_version = (
-            api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
+            api_version or litellm.api_version or get_secret_str("AZURE_API_VERSION")
        )

-        azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret(
+        azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret_str(
            "AZURE_AD_TOKEN"
        )

@@ -4892,8 +4935,8 @@ def transcription(
            api_key
            or litellm.api_key
            or litellm.azure_key
-            or get_secret("AZURE_API_KEY")
-        )  # type: ignore
+            or get_secret_str("AZURE_API_KEY")
+        )

        response = azure_audio_transcriptions.audio_transcriptions(
            model=model,
@@ -4942,6 +4985,9 @@ def transcription(
            api_base=api_base,
            api_key=api_key,
        )
+
+    if response is None:
+        raise ValueError("Unmapped provider passed in. Unable to get the response.")
    return response

@@ -5149,15 +5195,16 @@ def speech(
        vertex_ai_project = (
            generic_optional_params.vertex_project
            or litellm.vertex_project
-            or get_secret("VERTEXAI_PROJECT")
+            or get_secret_str("VERTEXAI_PROJECT")
        )
        vertex_ai_location = (
            generic_optional_params.vertex_location
            or litellm.vertex_location
-            or get_secret("VERTEXAI_LOCATION")
+            or get_secret_str("VERTEXAI_LOCATION")
        )
-        vertex_credentials = generic_optional_params.vertex_credentials or get_secret(
-            "VERTEXAI_CREDENTIALS"
+        vertex_credentials = (
+            generic_optional_params.vertex_credentials
+            or get_secret_str("VERTEXAI_CREDENTIALS")
        )

        if voice is not None and not isinstance(voice, dict):
@@ -5234,20 +5281,25 @@ async def ahealth_check(
        if custom_llm_provider == "azure":
            api_key = (
                model_params.get("api_key")
-                or get_secret("AZURE_API_KEY")
-                or get_secret("AZURE_OPENAI_API_KEY")
+                or get_secret_str("AZURE_API_KEY")
+                or get_secret_str("AZURE_OPENAI_API_KEY")
            )

-            api_base = (
+            api_base: Optional[str] = (
                model_params.get("api_base")
-                or get_secret("AZURE_API_BASE")
-                or get_secret("AZURE_OPENAI_API_BASE")
+                or get_secret_str("AZURE_API_BASE")
+                or get_secret_str("AZURE_OPENAI_API_BASE")
            )
+
+            if api_base is None:
+                raise ValueError(
+                    "Azure API Base cannot be None. Set via 'AZURE_API_BASE' in env var or `.completion(..., api_base=..)`"
+                )

            api_version = (
                model_params.get("api_version")
-                or get_secret("AZURE_API_VERSION")
-                or get_secret("AZURE_OPENAI_API_VERSION")
+                or get_secret_str("AZURE_API_VERSION")
+                or get_secret_str("AZURE_OPENAI_API_VERSION")
            )

            timeout = (
@@ -5273,7 +5325,7 @@ async def ahealth_check(
            custom_llm_provider == "openai"
            or custom_llm_provider == "text-completion-openai"
        ):
-            api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
+            api_key = model_params.get("api_key") or get_secret_str("OPENAI_API_KEY")
            organization = model_params.get("organization")

            timeout = (
@@ -5282,7 +5334,7 @@ async def ahealth_check(
                or default_timeout
            )

-            api_base = model_params.get("api_base") or get_secret("OPENAI_API_BASE")
+            api_base = model_params.get("api_base") or get_secret_str("OPENAI_API_BASE")

            if custom_llm_provider == "text-completion-openai":
                mode = "completion"
@@ -5377,7 +5429,9 @@ def config_completion(**kwargs):
        )


-def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] = None):
+def stream_chunk_builder_text_completion(
+    chunks: list, messages: Optional[List] = None
+) -> TextCompletionResponse:
    id = chunks[0]["id"]
    object = chunks[0]["object"]
    created = chunks[0]["created"]
@@ -5446,7 +5500,7 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
    response["usage"]["total_tokens"] = (
        response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
    )
-    return response
+    return TextCompletionResponse(**response)


def stream_chunk_builder(