Merge branch 'main' into litellm_bedrock_command_r_support

Krish Dholakia 2024-05-11 21:24:42 -07:00 committed by GitHub
commit 1d651c6049
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
82 changed files with 3661 additions and 605 deletions


@@ -33,6 +33,10 @@ from dataclasses import (
)
import litellm._service_logger # for storing API inputs, outputs, and metadata
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.caching import DualCache
oidc_cache = DualCache()
try:
# this works in python 3.8
@@ -107,7 +111,18 @@ try:
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
from typing import cast, List, Dict, Union, Optional, Literal, Any, BinaryIO, Iterable
from typing import (
cast,
List,
Dict,
Union,
Optional,
Literal,
Any,
BinaryIO,
Iterable,
Tuple,
)
from .caching import Cache
from concurrent.futures import ThreadPoolExecutor
@@ -2942,6 +2957,7 @@ def client(original_function):
)
else:
return result
return result
# Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print
@@ -3045,6 +3061,7 @@ def client(original_function):
model_response_object=ModelResponse(),
stream=kwargs.get("stream", False),
)
if kwargs.get("stream", False) == True:
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
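
The hunk above rebuilds a cache hit into a ModelResponse and, when the caller asked for stream=True, wraps it back into a stream. A minimal sketch of that replay pattern, using made-up stand-ins (make_chunks, cached) rather than litellm's CustomStreamWrapper:

# Illustrative only: replaying a cached completion as an OpenAI-style chunk stream.
def make_chunks(cached: dict):
    text = cached["choices"][0]["message"]["content"]
    for word in text.split():
        # Emit delta chunks so stream consumers work unchanged on a cache hit.
        yield {"choices": [{"delta": {"content": word + " "}}]}

cached = {"choices": [{"message": {"content": "hello from the cache"}}]}
for chunk in make_chunks(cached):
    print(chunk)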
@@ -5879,10 +5896,70 @@ def calculate_max_parallel_requests(
return None
def _is_region_eu(model_region: str) -> bool:
EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"]
for region in EU_Regions:
if "europe" in model_region.lower():
def _get_model_region(
custom_llm_provider: str, litellm_params: LiteLLM_Params
) -> Optional[str]:
"""
Return the region for a model, for a given provider
"""
if custom_llm_provider == "vertex_ai":
# check 'vertex_location'
vertex_ai_location = (
litellm_params.vertex_location
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
or get_secret("VERTEX_LOCATION")
)
if vertex_ai_location is not None and isinstance(vertex_ai_location, str):
return vertex_ai_location
elif custom_llm_provider == "bedrock":
aws_region_name = litellm_params.aws_region_name
if aws_region_name is not None:
return aws_region_name
elif custom_llm_provider == "watsonx":
watsonx_region_name = litellm_params.watsonx_region_name
if watsonx_region_name is not None:
return watsonx_region_name
return litellm_params.region_name
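
For context, a hedged example of calling the new _get_model_region helper; the field names come from the hunk itself, the import path and model string are assumptions.

# Hypothetical usage of _get_model_region (import path assumed).
from litellm.types.router import LiteLLM_Params

params = LiteLLM_Params(model="bedrock/cohere.command-r-v1:0", aws_region_name="eu-west-1")
print(_get_model_region(custom_llm_provider="bedrock", litellm_params=params))  # "eu-west-1"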
def _is_region_eu(litellm_params: LiteLLM_Params) -> bool:
"""
Return true/false if a deployment is in the EU
"""
if litellm_params.region_name == "eu":
return True
## ELSE ##
"""
- get provider
- get provider regions
- return true if given region (get_provider_region) in eu region (config.get_eu_regions())
"""
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=litellm_params.model, litellm_params=litellm_params
)
model_region = _get_model_region(
custom_llm_provider=custom_llm_provider, litellm_params=litellm_params
)
if model_region is None:
return False
if custom_llm_provider == "azure":
eu_regions = litellm.AzureOpenAIConfig().get_eu_regions()
elif custom_llm_provider == "vertex_ai":
eu_regions = litellm.VertexAIConfig().get_eu_regions()
elif custom_llm_provider == "bedrock":
eu_regions = litellm.AmazonBedrockGlobalConfig().get_eu_regions()
elif custom_llm_provider == "watsonx":
eu_regions = litellm.IBMWatsonXAIConfig().get_eu_regions()
else:
return False
for region in eu_regions:
if region in model_region.lower():
return True
return False
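
And a hedged check of the reworked _is_region_eu under the same assumptions: a Bedrock deployment pinned to an EU region should be flagged, a US one should not, assuming AmazonBedrockGlobalConfig().get_eu_regions() lists "eu-west-1".

eu = LiteLLM_Params(model="bedrock/cohere.command-r-v1:0", aws_region_name="eu-west-1")
us = LiteLLM_Params(model="bedrock/cohere.command-r-v1:0", aws_region_name="us-east-1")
print(_is_region_eu(eu))  # expected: True
print(_is_region_eu(us))  # expected: False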
@@ -6308,8 +6385,23 @@ def get_llm_provider(
custom_llm_provider: Optional[str] = None,
api_base: Optional[str] = None,
api_key: Optional[str] = None,
):
litellm_params: Optional[LiteLLM_Params] = None,
) -> Tuple[str, str, Optional[str], Optional[str]]:
"""
Returns the provider for a given model name - e.g. 'azure/chatgpt-v-2' -> 'azure'
For router -> Can also give the whole litellm param dict -> this function will extract the relevant details
"""
try:
## IF LITELLM PARAMS GIVEN ##
if litellm_params is not None:
assert (
custom_llm_provider is None and api_base is None and api_key is None
), "Either pass in litellm_params or the custom_llm_provider/api_base/api_key. Otherwise, these values will be overriden."
custom_llm_provider = litellm_params.custom_llm_provider
api_base = litellm_params.api_base
api_key = litellm_params.api_key
dynamic_api_key = None
# check if llm provider provided
# AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
@@ -6370,7 +6462,8 @@ def get_llm_provider(
api_base
or get_secret("MISTRAL_AZURE_API_BASE") # for Azure AI Mistral
or "https://api.mistral.ai/v1"
)
) # type: ignore
# if api_base does not end with /v1 we add it
if api_base is not None and not api_base.endswith(
"/v1"
@@ -6393,10 +6486,30 @@ def get_llm_provider(
or get_secret("TOGETHERAI_API_KEY")
or get_secret("TOGETHER_AI_TOKEN")
)
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(api_base)
)
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base
elif model.split("/", 1)[0] in litellm.provider_list:
custom_llm_provider = model.split("/", 1)[0]
model = model.split("/", 1)[1]
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(api_base)
)
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base
# check if api base is a known openai compatible endpoint
if api_base:
@@ -6420,7 +6533,22 @@ def get_llm_provider(
elif endpoint == "api.deepseek.com/v1":
custom_llm_provider = "deepseek"
dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
return model, custom_llm_provider, dynamic_api_key, api_base
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(
api_base
)
)
if dynamic_api_key is not None and not isinstance(
dynamic_api_key, str
):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base # type: ignore
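
A hedged example of the api_base matching path above; whether and how the scheme is stripped before comparison is assumed from the surrounding code.

# Assumed behavior: an OpenAI-compatible api_base resolves the provider and its key.
model, provider, key, base = litellm.get_llm_provider(
    model="deepseek-chat", api_base="https://api.deepseek.com/v1"
)
print(provider)  # expected: "deepseek" (key read from DEEPSEEK_API_KEY)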
# check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, etc.)
## openai - chatcompletion + text completion
@@ -6511,6 +6639,16 @@ def get_llm_provider(
),
llm_provider="",
)
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(api_base)
)
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
raise Exception(
"dynamic_api_key needs to be a string. dynamic_api_key={}".format(
dynamic_api_key
)
)
return model, custom_llm_provider, dynamic_api_key, api_base
except Exception as e:
if isinstance(e, litellm.exceptions.BadRequestError):
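
For the common provider-prefix case, a hedged usage sketch of the updated Tuple[str, str, Optional[str], Optional[str]] signature; the model string is illustrative.

model, provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="bedrock/cohere.command-r-v1:0"
)
print(model, provider)  # "cohere.command-r-v1:0" "bedrock"
# dynamic_api_key / api_base stay None unless a provider branch above sets them.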
@@ -8081,10 +8219,7 @@ def exception_type(
+ "Exception"
)
if (
"This model's maximum context length is" in error_str
or "Request too large" in error_str
):
if "This model's maximum context length is" in error_str:
exception_mapping_worked = True
raise ContextWindowExceededError(
message=f"{exception_provider} - {message} {extra_information}",
@@ -8125,6 +8260,13 @@ def exception_type(
model=model,
response=original_exception.response,
)
elif "Request too large" in error_str:
raise RateLimitError(
message=f"{exception_provider} - {message} {extra_information}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
)
elif (
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
in error_str
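
The two hunks above split the old combined check: context-length errors still raise ContextWindowExceededError, while "Request too large" now maps to RateLimitError. A standalone sketch of that classification (an illustrative helper, not litellm's exception_type):

def classify_openai_error(error_str: str) -> str:
    # Mirrors the string matching above, nothing more.
    if "This model's maximum context length is" in error_str:
        return "ContextWindowExceededError"
    if "Request too large" in error_str:
        return "RateLimitError"
    return "APIError"

print(classify_openai_error("Request too large for gpt-4o in organization ..."))  # RateLimitError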
@@ -9410,6 +9552,75 @@ def get_secret(
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
if secret_name.startswith("oidc/"):
secret_name_split = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
# TODO: Add caching for HTTP requests
if oidc_provider == "google":
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
response = oidc_client.get(
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
params={"audience": oidc_aud},
headers={"Metadata-Flavor": "Google"},
)
if response.status_code == 200:
oidc_token = response.text
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
return oidc_token
else:
raise ValueError("Google OIDC provider failed")
elif oidc_provider == "circleci":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
return env_secret
elif oidc_provider == "circleci_v2":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
return env_secret
elif oidc_provider == "github":
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if (
actions_id_token_request_url is None
or actions_id_token_request_token is None
):
raise ValueError(
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
)
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = oidc_client.get(
actions_id_token_request_url,
params={"audience": oidc_aud},
headers={
"Authorization": f"Bearer {actions_id_token_request_token}",
"Accept": "application/json; api-version=2.0",
},
)
if response.status_code == 200:
oidc_token = response.json()["value"]  # the token endpoint returns JSON: {"value": "<jwt>"}
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
return oidc_token
else:
raise ValueError("Github OIDC provider failed")
else:
raise ValueError("Unsupported OIDC provider")
try:
if litellm.secret_manager_client is not None:
try:
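
A hedged usage example of the new oidc/ secret scheme, run inside a GitHub Actions job with id-token: write permissions; the audience URL is illustrative.

# ACTIONS_ID_TOKEN_REQUEST_URL / ACTIONS_ID_TOKEN_REQUEST_TOKEN are injected by the runner.
token = get_secret("oidc/github/https://bedrock-runtime.us-east-1.amazonaws.com")
print(token[:20], "...")  # short-lived JWT, cached above for just under 5 minutes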
@@ -10364,7 +10575,12 @@ class CustomStreamWrapper:
response = chunk.replace("data: ", "").strip()
parsed_response = json.loads(response)
else:
return {"text": "", "is_finished": False}
return {
"text": "",
"is_finished": False,
"prompt_tokens": 0,
"completion_tokens": 0,
}
else:
print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
raise ValueError(
@@ -10379,13 +10595,47 @@ class CustomStreamWrapper:
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
"prompt_tokens": results[0].get("input_token_count", None),
"completion_tokens": results[0].get("generated_token_count", None),
"prompt_tokens": results[0].get("input_token_count", 0),
"completion_tokens": results[0].get("generated_token_count", 0),
}
return {"text": "", "is_finished": False}
except Exception as e:
raise e
def handle_clarifai_completion_chunk(self, chunk):
    try:
        if isinstance(chunk, dict):
            data_json = chunk
        elif isinstance(chunk, (str, bytes)):
            if isinstance(chunk, bytes):
                chunk = chunk.decode("utf-8")
            data_json = json.loads(chunk)
        else:
            return {"text": "", "is_finished": False}
        outputs = data_json.get("outputs") or [{}]
        text = outputs[0].get("data", {}).get("text", {}).get("raw", "")
        prompt_tokens = len(
            encoding.encode(
                outputs[0].get("input", {}).get("data", {}).get("text", {}).get("raw", "")
            )
        )
        completion_tokens = len(encoding.encode(text))
        return {
            "text": text,
            "is_finished": True,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
        }
    except Exception:
        traceback.print_exc()
        # Return the same shape as the success path so callers can index ["text"] safely.
        return {"text": "", "is_finished": False}
def model_response_creator(self):
model_response = ModelResponse(
stream=True, model=self.model, stream_options=self.stream_options
@@ -10431,6 +10681,9 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "clarifai":
response_obj = self.handle_clarifai_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]
elif self.model == "replicate" or self.custom_llm_provider == "replicate":
response_obj = self.handle_replicate_chunk(chunk)
completion_obj["content"] = response_obj["text"]