mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge branch 'main' into litellm_embedding_caching_updates
This commit is contained in:
commit
9ab59045a3
236 changed files with 24483 additions and 2014 deletions
253
litellm/main.py
253
litellm/main.py
|
@ -15,7 +15,7 @@ import dotenv, traceback, random, asyncio, time, contextvars
|
|||
from copy import deepcopy
|
||||
import httpx
|
||||
import litellm
|
||||
|
||||
from ._logging import verbose_logger
|
||||
from litellm import ( # type: ignore
|
||||
client,
|
||||
exception_type,
|
||||
|
@ -83,6 +83,7 @@ from litellm.utils import (
|
|||
TextCompletionResponse,
|
||||
TextChoices,
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
read_config_args,
|
||||
Choices,
|
||||
Message,
|
||||
|
@ -275,14 +276,10 @@ async def acompletion(
|
|||
else:
|
||||
# Call the synchronous function using run_in_executor
|
||||
response = await loop.run_in_executor(None, func_with_context) # type: ignore
|
||||
# if kwargs.get("stream", False): # return an async generator
|
||||
# return _async_streaming(
|
||||
# response=response,
|
||||
# model=model,
|
||||
# custom_llm_provider=custom_llm_provider,
|
||||
# args=args,
|
||||
# )
|
||||
# else:
|
||||
if isinstance(response, CustomStreamWrapper):
|
||||
response.set_logging_event_loop(
|
||||
loop=loop
|
||||
) # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
|
||||
return response
|
||||
except Exception as e:
|
||||
custom_llm_provider = custom_llm_provider or "openai"
|
||||
|
@ -345,6 +342,18 @@ def mock_completion(
|
|||
model_response["choices"][0]["message"]["content"] = mock_response
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
|
||||
model_response.usage = Usage(
|
||||
prompt_tokens=10, completion_tokens=20, total_tokens=30
|
||||
)
|
||||
|
||||
try:
|
||||
_, custom_llm_provider, _, _ = litellm.utils.get_llm_provider(model=model)
|
||||
model_response._hidden_params["custom_llm_provider"] = custom_llm_provider
|
||||
except:
|
||||
# don't let setting a hidden param block a mock_response
|
||||
pass
|
||||
|
||||
return model_response
|
||||
|
||||
except:
|
||||
|
@ -444,9 +453,12 @@ def completion(
|
|||
num_retries = kwargs.get("num_retries", None) ## deprecated
|
||||
max_retries = kwargs.get("max_retries", None)
|
||||
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
|
||||
organization = kwargs.get("organization", None)
|
||||
### CUSTOM MODEL COST ###
|
||||
input_cost_per_token = kwargs.get("input_cost_per_token", None)
|
||||
output_cost_per_token = kwargs.get("output_cost_per_token", None)
|
||||
input_cost_per_second = kwargs.get("input_cost_per_second", None)
|
||||
output_cost_per_second = kwargs.get("output_cost_per_second", None)
|
||||
### CUSTOM PROMPT TEMPLATE ###
|
||||
initial_prompt_value = kwargs.get("initial_prompt_value", None)
|
||||
roles = kwargs.get("roles", None)
|
||||
|
@ -524,6 +536,8 @@ def completion(
|
|||
"tpm",
|
||||
"input_cost_per_token",
|
||||
"output_cost_per_token",
|
||||
"input_cost_per_second",
|
||||
"output_cost_per_second",
|
||||
"hf_model_name",
|
||||
"model_info",
|
||||
"proxy_server_request",
|
||||
|
@ -536,10 +550,6 @@ def completion(
|
|||
non_default_params = {
|
||||
k: v for k, v in kwargs.items() if k not in default_params
|
||||
} # model-specific params - pass them straight to the model/provider
|
||||
if mock_response:
|
||||
return mock_completion(
|
||||
model, messages, stream=stream, mock_response=mock_response
|
||||
)
|
||||
if timeout is None:
|
||||
timeout = (
|
||||
kwargs.get("request_timeout", None) or 600
|
||||
|
@ -577,15 +587,45 @@ def completion(
|
|||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
)
|
||||
if model_response is not None and hasattr(model_response, "_hidden_params"):
|
||||
model_response._hidden_params["custom_llm_provider"] = custom_llm_provider
|
||||
model_response._hidden_params["region_name"] = kwargs.get(
|
||||
"aws_region_name", None
|
||||
) # support region-based pricing for bedrock
|
||||
|
||||
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
|
||||
if input_cost_per_token is not None and output_cost_per_token is not None:
|
||||
print_verbose(f"Registering model={model} in model cost map")
|
||||
litellm.register_model(
|
||||
{
|
||||
f"{custom_llm_provider}/{model}": {
|
||||
"input_cost_per_token": input_cost_per_token,
|
||||
"output_cost_per_token": output_cost_per_token,
|
||||
"litellm_provider": custom_llm_provider,
|
||||
},
|
||||
model: {
|
||||
"input_cost_per_token": input_cost_per_token,
|
||||
"output_cost_per_token": output_cost_per_token,
|
||||
"litellm_provider": custom_llm_provider,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
elif (
|
||||
input_cost_per_second is not None
|
||||
): # time based pricing just needs cost in place
|
||||
output_cost_per_second = output_cost_per_second or 0.0
|
||||
litellm.register_model(
|
||||
{
|
||||
f"{custom_llm_provider}/{model}": {
|
||||
"input_cost_per_second": input_cost_per_second,
|
||||
"output_cost_per_second": output_cost_per_second,
|
||||
"litellm_provider": custom_llm_provider,
|
||||
},
|
||||
model: {
|
||||
"input_cost_per_second": input_cost_per_second,
|
||||
"output_cost_per_second": output_cost_per_second,
|
||||
"litellm_provider": custom_llm_provider,
|
||||
},
|
||||
}
|
||||
)
|
||||
### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ###
|
||||
|
@ -674,6 +714,10 @@ def completion(
|
|||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if mock_response:
|
||||
return mock_completion(
|
||||
model, messages, stream=stream, mock_response=mock_response
|
||||
)
|
||||
if custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||
|
@ -692,9 +736,9 @@ def completion(
|
|||
or get_secret("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
|
||||
"AZURE_AD_TOKEN"
|
||||
)
|
||||
azure_ad_token = optional_params.get("extra_body", {}).pop(
|
||||
"azure_ad_token", None
|
||||
) or get_secret("AZURE_AD_TOKEN")
|
||||
|
||||
headers = headers or litellm.headers
|
||||
|
||||
|
@ -758,7 +802,8 @@ def completion(
|
|||
or "https://api.openai.com/v1"
|
||||
)
|
||||
openai.organization = (
|
||||
litellm.organization
|
||||
organization
|
||||
or litellm.organization
|
||||
or get_secret("OPENAI_ORGANIZATION")
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
|
@ -798,6 +843,7 @@ def completion(
|
|||
timeout=timeout,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
client=client, # pass AsyncOpenAI, OpenAI client
|
||||
organization=organization,
|
||||
)
|
||||
except Exception as e:
|
||||
## LOGGING - log the original exception returned
|
||||
|
@ -967,6 +1013,7 @@ def completion(
|
|||
encoding=encoding, # for calculating input/output tokens
|
||||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
headers=headers,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# don't try to access stream object,
|
||||
|
@ -1376,11 +1423,29 @@ def completion(
|
|||
acompletion=acompletion,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
)
|
||||
if (
|
||||
"stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
and acompletion == False
|
||||
):
|
||||
response = CustomStreamWrapper(
|
||||
iter(model_response),
|
||||
model,
|
||||
custom_llm_provider="gemini",
|
||||
logging_obj=logging,
|
||||
)
|
||||
return response
|
||||
response = model_response
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
vertex_ai_project = litellm.vertex_project or get_secret("VERTEXAI_PROJECT")
|
||||
vertex_ai_location = litellm.vertex_location or get_secret(
|
||||
"VERTEXAI_LOCATION"
|
||||
vertex_ai_project = (
|
||||
optional_params.pop("vertex_ai_project", None)
|
||||
or litellm.vertex_project
|
||||
or get_secret("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.pop("vertex_ai_location", None)
|
||||
or litellm.vertex_location
|
||||
or get_secret("VERTEXAI_LOCATION")
|
||||
)
|
||||
|
||||
model_response = vertex_ai.completion(
|
||||
|
@ -1471,19 +1536,22 @@ def completion(
|
|||
if (
|
||||
"stream" in optional_params and optional_params["stream"] == True
|
||||
): ## [BETA]
|
||||
# sagemaker does not support streaming as of now so we're faking streaming:
|
||||
# https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611
|
||||
# "SageMaker is currently not supporting streaming responses."
|
||||
|
||||
# fake streaming for sagemaker
|
||||
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
|
||||
resp_string = model_response["choices"][0]["message"]["content"]
|
||||
from .llms.sagemaker import TokenIterator
|
||||
|
||||
tokenIterator = TokenIterator(model_response)
|
||||
response = CustomStreamWrapper(
|
||||
resp_string,
|
||||
model,
|
||||
completion_stream=tokenIterator,
|
||||
model=model,
|
||||
custom_llm_provider="sagemaker",
|
||||
logging_obj=logging,
|
||||
)
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=None,
|
||||
original_response=response,
|
||||
)
|
||||
return response
|
||||
|
||||
## RESPONSE OBJECT
|
||||
|
@ -2146,6 +2214,7 @@ async def aembedding(*args, **kwargs):
|
|||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "vertex_ai"
|
||||
): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
# Await normally
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -2158,6 +2227,8 @@ async def aembedding(*args, **kwargs):
|
|||
else:
|
||||
# Call the synchronous function using run_in_executor
|
||||
response = await loop.run_in_executor(None, func_with_context)
|
||||
if response is not None and hasattr(response, "_hidden_params"):
|
||||
response._hidden_params["custom_llm_provider"] = custom_llm_provider
|
||||
return response
|
||||
except Exception as e:
|
||||
custom_llm_provider = custom_llm_provider or "openai"
|
||||
|
@ -2174,6 +2245,7 @@ def embedding(
|
|||
model,
|
||||
input=[],
|
||||
# Optional params
|
||||
dimensions: Optional[int] = None,
|
||||
timeout=600, # default to 10 minutes
|
||||
# set api_base, api_version, api_key
|
||||
api_base: Optional[str] = None,
|
||||
|
@ -2194,6 +2266,7 @@ def embedding(
|
|||
Parameters:
|
||||
- model: The embedding model to use.
|
||||
- input: The input for which embeddings are to be generated.
|
||||
- dimensions: The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
|
||||
- timeout: The timeout value for the API call, default 10 mins
|
||||
- litellm_call_id: The call ID for litellm logging.
|
||||
- litellm_logging_obj: The litellm logging object.
|
||||
|
@ -2220,8 +2293,14 @@ def embedding(
|
|||
encoding_format = kwargs.get("encoding_format", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
aembedding = kwargs.get("aembedding", None)
|
||||
### CUSTOM MODEL COST ###
|
||||
input_cost_per_token = kwargs.get("input_cost_per_token", None)
|
||||
output_cost_per_token = kwargs.get("output_cost_per_token", None)
|
||||
input_cost_per_second = kwargs.get("input_cost_per_second", None)
|
||||
output_cost_per_second = kwargs.get("output_cost_per_second", None)
|
||||
openai_params = [
|
||||
"user",
|
||||
"dimensions",
|
||||
"request_timeout",
|
||||
"api_base",
|
||||
"api_version",
|
||||
|
@ -2268,6 +2347,8 @@ def embedding(
|
|||
"tpm",
|
||||
"input_cost_per_token",
|
||||
"output_cost_per_token",
|
||||
"input_cost_per_second",
|
||||
"output_cost_per_second",
|
||||
"hf_model_name",
|
||||
"proxy_server_request",
|
||||
"model_info",
|
||||
|
@ -2288,11 +2369,35 @@ def embedding(
|
|||
api_key=api_key,
|
||||
)
|
||||
optional_params = get_optional_params_embeddings(
|
||||
model=model,
|
||||
user=user,
|
||||
dimensions=dimensions,
|
||||
encoding_format=encoding_format,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
**non_default_params,
|
||||
)
|
||||
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
|
||||
if input_cost_per_token is not None and output_cost_per_token is not None:
|
||||
litellm.register_model(
|
||||
{
|
||||
model: {
|
||||
"input_cost_per_token": input_cost_per_token,
|
||||
"output_cost_per_token": output_cost_per_token,
|
||||
"litellm_provider": custom_llm_provider,
|
||||
}
|
||||
}
|
||||
)
|
||||
if input_cost_per_second is not None: # time based pricing just needs cost in place
|
||||
output_cost_per_second = output_cost_per_second or 0.0
|
||||
litellm.register_model(
|
||||
{
|
||||
model: {
|
||||
"input_cost_per_second": input_cost_per_second,
|
||||
"output_cost_per_second": output_cost_per_second,
|
||||
"litellm_provider": custom_llm_provider,
|
||||
}
|
||||
}
|
||||
)
|
||||
try:
|
||||
response = None
|
||||
logging = litellm_logging_obj
|
||||
|
@ -2385,7 +2490,7 @@ def embedding(
|
|||
client=client,
|
||||
aembedding=aembedding,
|
||||
)
|
||||
elif model in litellm.cohere_embedding_models:
|
||||
elif custom_llm_provider == "cohere":
|
||||
cohere_key = (
|
||||
api_key
|
||||
or litellm.cohere_key
|
||||
|
@ -2427,6 +2532,29 @@ def embedding(
|
|||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
vertex_ai_project = (
|
||||
optional_params.pop("vertex_ai_project", None)
|
||||
or litellm.vertex_project
|
||||
or get_secret("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.pop("vertex_ai_location", None)
|
||||
or litellm.vertex_location
|
||||
or get_secret("VERTEXAI_LOCATION")
|
||||
)
|
||||
|
||||
response = vertex_ai.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "oobabooga":
|
||||
response = oobabooga.embedding(
|
||||
model=model,
|
||||
|
@ -2513,6 +2641,8 @@ def embedding(
|
|||
else:
|
||||
args = locals()
|
||||
raise ValueError(f"No valid embedding model args passed in - {args}")
|
||||
if response is not None and hasattr(response, "_hidden_params"):
|
||||
response._hidden_params["custom_llm_provider"] = custom_llm_provider
|
||||
return response
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
|
@ -2523,9 +2653,7 @@ def embedding(
|
|||
)
|
||||
## Map to OpenAI Exception
|
||||
raise exception_type(
|
||||
model=model,
|
||||
original_exception=e,
|
||||
custom_llm_provider="azure" if azure == True else None,
|
||||
model=model, original_exception=e, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
|
||||
|
@ -2914,6 +3042,7 @@ def image_generation(
|
|||
else:
|
||||
model = "dall-e-2"
|
||||
custom_llm_provider = "openai" # default to dall-e-2 on openai
|
||||
model_response._hidden_params["model"] = model
|
||||
openai_params = [
|
||||
"user",
|
||||
"request_timeout",
|
||||
|
@ -2987,7 +3116,7 @@ def image_generation(
|
|||
custom_llm_provider=custom_llm_provider,
|
||||
**non_default_params,
|
||||
)
|
||||
logging = litellm_logging_obj
|
||||
logging: Logging = litellm_logging_obj
|
||||
logging.update_environment_variables(
|
||||
model=model,
|
||||
user=user,
|
||||
|
@ -3051,7 +3180,18 @@ def image_generation(
|
|||
model_response=model_response,
|
||||
aimg_generation=aimg_generation,
|
||||
)
|
||||
|
||||
elif custom_llm_provider == "bedrock":
|
||||
if model is None:
|
||||
raise Exception("Model needs to be set for bedrock")
|
||||
model_response = bedrock.image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
timeout=timeout,
|
||||
logging_obj=litellm_logging_obj,
|
||||
optional_params=optional_params,
|
||||
model_response=model_response,
|
||||
aimg_generation=aimg_generation,
|
||||
)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
## Map to OpenAI Exception
|
||||
|
@ -3068,7 +3208,9 @@ def image_generation(
|
|||
|
||||
async def ahealth_check(
|
||||
model_params: dict,
|
||||
mode: Optional[Literal["completion", "embedding", "image_generation"]] = None,
|
||||
mode: Optional[
|
||||
Literal["completion", "embedding", "image_generation", "chat"]
|
||||
] = None,
|
||||
prompt: Optional[str] = None,
|
||||
input: Optional[List] = None,
|
||||
default_timeout: float = 6000,
|
||||
|
@ -3085,8 +3227,11 @@ async def ahealth_check(
|
|||
if model is None:
|
||||
raise Exception("model not set")
|
||||
|
||||
if model in litellm.model_cost and mode is None:
|
||||
mode = litellm.model_cost[model]["mode"]
|
||||
|
||||
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
|
||||
mode = mode or "completion" # default to completion calls
|
||||
mode = mode or "chat" # default to chat completion calls
|
||||
|
||||
if custom_llm_provider == "azure":
|
||||
api_key = (
|
||||
|
@ -3126,8 +3271,12 @@ async def ahealth_check(
|
|||
prompt=prompt,
|
||||
input=input,
|
||||
)
|
||||
elif custom_llm_provider == "openai":
|
||||
elif (
|
||||
custom_llm_provider == "openai"
|
||||
or custom_llm_provider == "text-completion-openai"
|
||||
):
|
||||
api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
|
||||
organization = model_params.get("organization")
|
||||
|
||||
timeout = (
|
||||
model_params.get("timeout")
|
||||
|
@ -3145,8 +3294,12 @@ async def ahealth_check(
|
|||
mode=mode,
|
||||
prompt=prompt,
|
||||
input=input,
|
||||
organization=organization,
|
||||
)
|
||||
else:
|
||||
model_params["cache"] = {
|
||||
"no-cache": True
|
||||
} # don't use cached responses when making health check calls
|
||||
if mode == "embedding":
|
||||
model_params.pop("messages", None)
|
||||
model_params["input"] = input
|
||||
|
@ -3169,6 +3322,7 @@ async def ahealth_check(
|
|||
## Set verbose to true -> ```litellm.set_verbose = True```
|
||||
def print_verbose(print_statement):
|
||||
try:
|
||||
verbose_logger.debug(print_statement)
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
except:
|
||||
|
@ -3256,7 +3410,23 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
|
|||
return response
|
||||
|
||||
|
||||
def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
|
||||
def stream_chunk_builder(
|
||||
chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
|
||||
):
|
||||
model_response = litellm.ModelResponse()
|
||||
### SORT CHUNKS BASED ON CREATED ORDER ##
|
||||
print_verbose("Goes into checking if chunk has hiddden created at param")
|
||||
if chunks[0]._hidden_params.get("created_at", None):
|
||||
print_verbose("Chunks have a created at hidden param")
|
||||
# Sort chunks based on created_at in ascending order
|
||||
chunks = sorted(
|
||||
chunks, key=lambda x: x._hidden_params.get("created_at", float("inf"))
|
||||
)
|
||||
print_verbose("Chunks sorted")
|
||||
|
||||
# set hidden params from chunk to model_response
|
||||
if model_response is not None and hasattr(model_response, "_hidden_params"):
|
||||
model_response._hidden_params = chunks[0].get("_hidden_params", {})
|
||||
id = chunks[0]["id"]
|
||||
object = chunks[0]["object"]
|
||||
created = chunks[0]["created"]
|
||||
|
@ -3427,5 +3597,8 @@ def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
|
|||
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
|
||||
)
|
||||
return convert_to_model_response_object(
|
||||
response_object=response, model_response_object=litellm.ModelResponse()
|
||||
response_object=response,
|
||||
model_response_object=model_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue