forked from phoenix/litellm-mirror
Merge branch 'main' into litellm_bedrock_command_r_support
This commit is contained in:
commit
1d651c6049
82 changed files with 3661 additions and 605 deletions
283
litellm/utils.py
283
litellm/utils.py
|
@ -33,6 +33,10 @@ from dataclasses import (
|
|||
)
|
||||
|
||||
import litellm._service_logger # for storing API inputs, outputs, and metadata
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm.caching import DualCache
|
||||
|
||||
oidc_cache = DualCache()
|
||||
|
||||
try:
|
||||
# this works in python 3.8
|
||||
|
@ -107,7 +111,18 @@ try:
|
|||
except Exception as e:
|
||||
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
|
||||
|
||||
from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO, Iterable
|
||||
from typing import (
|
||||
cast,
|
||||
List,
|
||||
Dict,
|
||||
Union,
|
||||
Optional,
|
||||
Literal,
|
||||
Any,
|
||||
BinaryIO,
|
||||
Iterable,
|
||||
Tuple,
|
||||
)
|
||||
from .caching import Cache
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
@ -2942,6 +2957,7 @@ def client(original_function):
|
|||
)
|
||||
else:
|
||||
return result
|
||||
|
||||
return result
|
||||
|
||||
# Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print
|
||||
|
@ -3045,6 +3061,7 @@ def client(original_function):
|
|||
model_response_object=ModelResponse(),
|
||||
stream=kwargs.get("stream", False),
|
||||
)
|
||||
|
||||
if kwargs.get("stream", False) == True:
|
||||
cached_result = CustomStreamWrapper(
|
||||
completion_stream=cached_result,
|
||||
|
@ -5879,10 +5896,70 @@ def calculate_max_parallel_requests(
|
|||
return None
|
||||
|
||||
|
||||
def _is_region_eu(model_region: str) -> bool:
|
||||
EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"]
|
||||
for region in EU_Regions:
|
||||
if "europe" in model_region.lower():
|
||||
def _get_model_region(
|
||||
custom_llm_provider: str, litellm_params: LiteLLM_Params
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Return the region for a model, for a given provider
|
||||
"""
|
||||
if custom_llm_provider == "vertex_ai":
|
||||
# check 'vertex_location'
|
||||
vertex_ai_location = (
|
||||
litellm_params.vertex_location
|
||||
or litellm.vertex_location
|
||||
or get_secret("VERTEXAI_LOCATION")
|
||||
or get_secret("VERTEX_LOCATION")
|
||||
)
|
||||
if vertex_ai_location is not None and isinstance(vertex_ai_location, str):
|
||||
return vertex_ai_location
|
||||
elif custom_llm_provider == "bedrock":
|
||||
aws_region_name = litellm_params.aws_region_name
|
||||
if aws_region_name is not None:
|
||||
return aws_region_name
|
||||
elif custom_llm_provider == "watsonx":
|
||||
watsonx_region_name = litellm_params.watsonx_region_name
|
||||
if watsonx_region_name is not None:
|
||||
return watsonx_region_name
|
||||
return litellm_params.region_name
|
||||
|
||||
|
||||
def _is_region_eu(litellm_params: LiteLLM_Params) -> bool:
|
||||
"""
|
||||
Return true/false if a deployment is in the EU
|
||||
"""
|
||||
if litellm_params.region_name == "eu":
|
||||
return True
|
||||
|
||||
## ELSE ##
|
||||
"""
|
||||
- get provider
|
||||
- get provider regions
|
||||
- return true if given region (get_provider_region) in eu region (config.get_eu_regions())
|
||||
"""
|
||||
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
|
||||
model=litellm_params.model, litellm_params=litellm_params
|
||||
)
|
||||
|
||||
model_region = _get_model_region(
|
||||
custom_llm_provider=custom_llm_provider, litellm_params=litellm_params
|
||||
)
|
||||
|
||||
if model_region is None:
|
||||
return False
|
||||
|
||||
if custom_llm_provider == "azure":
|
||||
eu_regions = litellm.AzureOpenAIConfig().get_eu_regions()
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
eu_regions = litellm.VertexAIConfig().get_eu_regions()
|
||||
elif custom_llm_provider == "bedrock":
|
||||
eu_regions = litellm.AmazonBedrockGlobalConfig().get_eu_regions()
|
||||
elif custom_llm_provider == "watsonx":
|
||||
eu_regions = litellm.IBMWatsonXAIConfig().get_eu_regions()
|
||||
else:
|
||||
return False
|
||||
|
||||
for region in eu_regions:
|
||||
if region in model_region.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -6308,8 +6385,23 @@ def get_llm_provider(
|
|||
custom_llm_provider: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
litellm_params: Optional[LiteLLM_Params] = None,
|
||||
) -> Tuple[str, str, Optional[str], Optional[str]]:
|
||||
"""
|
||||
Returns the provider for a given model name - e.g. 'azure/chatgpt-v-2' -> 'azure'
|
||||
|
||||
For router -> Can also give the whole litellm param dict -> this function will extract the relevant details
|
||||
"""
|
||||
try:
|
||||
## IF LITELLM PARAMS GIVEN ##
|
||||
if litellm_params is not None:
|
||||
assert (
|
||||
custom_llm_provider is None and api_base is None and api_key is None
|
||||
), "Either pass in litellm_params or the custom_llm_provider/api_base/api_key. Otherwise, these values will be overriden."
|
||||
custom_llm_provider = litellm_params.custom_llm_provider
|
||||
api_base = litellm_params.api_base
|
||||
api_key = litellm_params.api_key
|
||||
|
||||
dynamic_api_key = None
|
||||
# check if llm provider provided
|
||||
# AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
|
||||
|
@ -6370,7 +6462,8 @@ def get_llm_provider(
|
|||
api_base
|
||||
or get_secret("MISTRAL_AZURE_API_BASE") # for Azure AI Mistral
|
||||
or "https://api.mistral.ai/v1"
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
# if api_base does not end with /v1 we add it
|
||||
if api_base is not None and not api_base.endswith(
|
||||
"/v1"
|
||||
|
@ -6393,10 +6486,30 @@ def get_llm_provider(
|
|||
or get_secret("TOGETHERAI_API_KEY")
|
||||
or get_secret("TOGETHER_AI_TOKEN")
|
||||
)
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception(
|
||||
"api base needs to be a string. api_base={}".format(api_base)
|
||||
)
|
||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||
raise Exception(
|
||||
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||
dynamic_api_key
|
||||
)
|
||||
)
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
elif model.split("/", 1)[0] in litellm.provider_list:
|
||||
custom_llm_provider = model.split("/", 1)[0]
|
||||
model = model.split("/", 1)[1]
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception(
|
||||
"api base needs to be a string. api_base={}".format(api_base)
|
||||
)
|
||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||
raise Exception(
|
||||
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||
dynamic_api_key
|
||||
)
|
||||
)
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
# check if api base is a known openai compatible endpoint
|
||||
if api_base:
|
||||
|
@ -6420,7 +6533,22 @@ def get_llm_provider(
|
|||
elif endpoint == "api.deepseek.com/v1":
|
||||
custom_llm_provider = "deepseek"
|
||||
dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception(
|
||||
"api base needs to be a string. api_base={}".format(
|
||||
api_base
|
||||
)
|
||||
)
|
||||
if dynamic_api_key is not None and not isinstance(
|
||||
dynamic_api_key, str
|
||||
):
|
||||
raise Exception(
|
||||
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||
dynamic_api_key
|
||||
)
|
||||
)
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base # type: ignore
|
||||
|
||||
# check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
|
||||
## openai - chatcompletion + text completion
|
||||
|
@ -6511,6 +6639,16 @@ def get_llm_provider(
|
|||
),
|
||||
llm_provider="",
|
||||
)
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception(
|
||||
"api base needs to be a string. api_base={}".format(api_base)
|
||||
)
|
||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||
raise Exception(
|
||||
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
|
||||
dynamic_api_key
|
||||
)
|
||||
)
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
except Exception as e:
|
||||
if isinstance(e, litellm.exceptions.BadRequestError):
|
||||
|
@ -8081,10 +8219,7 @@ def exception_type(
|
|||
+ "Exception"
|
||||
)
|
||||
|
||||
if (
|
||||
"This model's maximum context length is" in error_str
|
||||
or "Request too large" in error_str
|
||||
):
|
||||
if "This model's maximum context length is" in error_str:
|
||||
exception_mapping_worked = True
|
||||
raise ContextWindowExceededError(
|
||||
message=f"{exception_provider} - {message} {extra_information}",
|
||||
|
@ -8125,6 +8260,13 @@ def exception_type(
|
|||
model=model,
|
||||
response=original_exception.response,
|
||||
)
|
||||
elif "Request too large" in error_str:
|
||||
raise RateLimitError(
|
||||
message=f"{exception_provider} - {message} {extra_information}",
|
||||
model=model,
|
||||
llm_provider=custom_llm_provider,
|
||||
response=original_exception.response,
|
||||
)
|
||||
elif (
|
||||
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
|
||||
in error_str
|
||||
|
@ -9410,6 +9552,75 @@ def get_secret(
|
|||
if secret_name.startswith("os.environ/"):
|
||||
secret_name = secret_name.replace("os.environ/", "")
|
||||
|
||||
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
|
||||
if secret_name.startswith("oidc/"):
|
||||
secret_name_split = secret_name.replace("oidc/", "")
|
||||
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
|
||||
# TODO: Add caching for HTTP requests
|
||||
if oidc_provider == "google":
|
||||
oidc_token = oidc_cache.get_cache(key=secret_name)
|
||||
if oidc_token is not None:
|
||||
return oidc_token
|
||||
|
||||
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
|
||||
response = oidc_client.get(
|
||||
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
|
||||
params={"audience": oidc_aud},
|
||||
headers={"Metadata-Flavor": "Google"},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
oidc_token = response.text
|
||||
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
|
||||
return oidc_token
|
||||
else:
|
||||
raise ValueError("Google OIDC provider failed")
|
||||
elif oidc_provider == "circleci":
|
||||
# https://circleci.com/docs/openid-connect-tokens/
|
||||
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
|
||||
if env_secret is None:
|
||||
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
|
||||
return env_secret
|
||||
elif oidc_provider == "circleci_v2":
|
||||
# https://circleci.com/docs/openid-connect-tokens/
|
||||
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
|
||||
if env_secret is None:
|
||||
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
|
||||
return env_secret
|
||||
elif oidc_provider == "github":
|
||||
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
|
||||
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
|
||||
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
|
||||
if (
|
||||
actions_id_token_request_url is None
|
||||
or actions_id_token_request_token is None
|
||||
):
|
||||
raise ValueError(
|
||||
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
|
||||
)
|
||||
|
||||
oidc_token = oidc_cache.get_cache(key=secret_name)
|
||||
if oidc_token is not None:
|
||||
return oidc_token
|
||||
|
||||
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
response = oidc_client.get(
|
||||
actions_id_token_request_url,
|
||||
params={"audience": oidc_aud},
|
||||
headers={
|
||||
"Authorization": f"Bearer {actions_id_token_request_token}",
|
||||
"Accept": "application/json; api-version=2.0",
|
||||
},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
oidc_token = response.text["value"]
|
||||
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
|
||||
return oidc_token
|
||||
else:
|
||||
raise ValueError("Github OIDC provider failed")
|
||||
else:
|
||||
raise ValueError("Unsupported OIDC provider")
|
||||
|
||||
try:
|
||||
if litellm.secret_manager_client is not None:
|
||||
try:
|
||||
|
@ -10364,7 +10575,12 @@ class CustomStreamWrapper:
|
|||
response = chunk.replace("data: ", "").strip()
|
||||
parsed_response = json.loads(response)
|
||||
else:
|
||||
return {"text": "", "is_finished": False}
|
||||
return {
|
||||
"text": "",
|
||||
"is_finished": False,
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
}
|
||||
else:
|
||||
print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
|
||||
raise ValueError(
|
||||
|
@ -10379,13 +10595,47 @@ class CustomStreamWrapper:
|
|||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
"prompt_tokens": results[0].get("input_token_count", None),
|
||||
"completion_tokens": results[0].get("generated_token_count", None),
|
||||
"prompt_tokens": results[0].get("input_token_count", 0),
|
||||
"completion_tokens": results[0].get("generated_token_count", 0),
|
||||
}
|
||||
return {"text": "", "is_finished": False}
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def handle_clarifai_completion_chunk(self, chunk):
|
||||
try:
|
||||
if isinstance(chunk, dict):
|
||||
parsed_response = chunk
|
||||
if isinstance(chunk, (str, bytes)):
|
||||
if isinstance(chunk, bytes):
|
||||
parsed_response = chunk.decode("utf-8")
|
||||
else:
|
||||
parsed_response = chunk
|
||||
data_json = json.loads(parsed_response)
|
||||
text = (
|
||||
data_json.get("outputs", "")[0]
|
||||
.get("data", "")
|
||||
.get("text", "")
|
||||
.get("raw", "")
|
||||
)
|
||||
prompt_tokens = len(
|
||||
encoding.encode(
|
||||
data_json.get("outputs", "")[0]
|
||||
.get("input", "")
|
||||
.get("data", "")
|
||||
.get("text", "")
|
||||
.get("raw", "")
|
||||
)
|
||||
)
|
||||
completion_tokens = len(encoding.encode(text))
|
||||
return {
|
||||
"text": text,
|
||||
"is_finished": True,
|
||||
}
|
||||
except:
|
||||
traceback.print_exc()
|
||||
return ""
|
||||
|
||||
def model_response_creator(self):
|
||||
model_response = ModelResponse(
|
||||
stream=True, model=self.model, stream_options=self.stream_options
|
||||
|
@ -10431,6 +10681,9 @@ class CustomStreamWrapper:
|
|||
completion_obj["content"] = response_obj["text"]
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
elif self.custom_llm_provider and self.custom_llm_provider == "clarifai":
|
||||
response_obj = self.handle_clarifai_completion_chunk(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
elif self.model == "replicate" or self.custom_llm_provider == "replicate":
|
||||
response_obj = self.handle_replicate_chunk(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue