diff --git a/docs/my-website/docs/providers/litellm_proxy.md b/docs/my-website/docs/providers/litellm_proxy.md
index 2914b09f2..69377b27f 100644
--- a/docs/my-website/docs/providers/litellm_proxy.md
+++ b/docs/my-website/docs/providers/litellm_proxy.md
@@ -9,13 +9,13 @@ import TabItem from '@theme/TabItem';
 
 :::
 
-**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**, you just need the `openai/` prefix before the model
+**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**, you just need the `litellm_proxy/` prefix before the model
 
 ## Required Variables
 
 ```python
-os.environ["OPENAI_API_KEY"] = "" # "sk-1234" your litellm proxy api key
-os.environ["OPENAI_API_BASE"] = "" # "http://localhost:4000" your litellm proxy api base
+os.environ["LITELLM_PROXY_API_KEY"] = "" # "sk-1234" your litellm proxy api key
+os.environ["LITELLM_PROXY_API_BASE"] = "" # "http://localhost:4000" your litellm proxy api base
 ```
 
@@ -25,18 +25,18 @@ import os
 import litellm
 from litellm import completion
 
-os.environ["OPENAI_API_KEY"] = ""
+os.environ["LITELLM_PROXY_API_KEY"] = ""
 
 # set custom api base to your proxy
 # either set .env or litellm.api_base
-# os.environ["OPENAI_API_BASE"] = ""
+# os.environ["LITELLM_PROXY_API_BASE"] = ""
 litellm.api_base = "your-openai-proxy-url"
 
 messages = [{ "content": "Hello, how are you?","role": "user"}]
 
-# openai call
-response = completion(model="openai/your-model-name", messages)
+# litellm proxy call
+response = completion(model="litellm_proxy/your-model-name", messages)
 ```
 
 ## Usage - passing `api_base`, `api_key` per request
 
@@ -48,13 +48,13 @@ import os
 import litellm
 from litellm import completion
 
-os.environ["OPENAI_API_KEY"] = ""
+os.environ["LITELLM_PROXY_API_KEY"] = ""
 
 messages = [{ "content": "Hello, how are you?","role": "user"}]
 
-# openai call
+# litellm proxy call
 response = completion(
-    model="openai/your-model-name",
+    model="litellm_proxy/your-model-name",
     messages,
     api_base = "your-litellm-proxy-url",
     api_key = "your-litellm-proxy-api-key"
 )
@@ -67,13 +67,13 @@ import os
 import litellm
 from litellm import completion
 
-os.environ["OPENAI_API_KEY"] = ""
+os.environ["LITELLM_PROXY_API_KEY"] = ""
 
 messages = [{ "content": "Hello, how are you?","role": "user"}]
 
-# openai call
+# litellm proxy call
 response = completion(
-    model="openai/your-model-name",
+    model="litellm_proxy/your-model-name",
     messages,
     api_base = "your-litellm-proxy-url",
     stream=True
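Since `litellm_proxy` is registered as an OpenAI-compatible provider later in this patch, the same prefix should also cover non-chat routes. A minimal sketch (not part of the docs change; the embedding model name is an assumption, and the new `LITELLM_PROXY_*` env vars are assumed to be set):

```python
import os
import litellm

os.environ["LITELLM_PROXY_API_KEY"] = "sk-1234"                  # your proxy virtual key
os.environ["LITELLM_PROXY_API_BASE"] = "http://localhost:4000"   # your proxy base url

# hypothetical embedding call routed through the proxy via the new prefix
response = litellm.embedding(
    model="litellm_proxy/my-embedding-model",
    input=["hello world"],
)
print(len(response.data[0]["embedding"]))
```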
diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md
index 87925516a..167e90916 100644
--- a/docs/my-website/docs/routing.md
+++ b/docs/my-website/docs/routing.md
@@ -183,9 +183,9 @@ model_list = [{ # list of model deployments
         "api_key": os.getenv("AZURE_API_KEY"),
         "api_version": os.getenv("AZURE_API_VERSION"),
-        "api_base": os.getenv("AZURE_API_BASE")
+        "api_base": os.getenv("AZURE_API_BASE"),
+        "tpm": 100000,
+        "rpm": 10000,
     },
-    "tpm": 100000,
-    "rpm": 10000,
 }, {
     "model_name": "gpt-3.5-turbo",
     "litellm_params": { # params for litellm completion/embedding call
@@ -193,24 +193,24 @@ model_list = [{ # list of model deployments
         "api_key": os.getenv("AZURE_API_KEY"),
         "api_version": os.getenv("AZURE_API_VERSION"),
-        "api_base": os.getenv("AZURE_API_BASE")
+        "api_base": os.getenv("AZURE_API_BASE"),
+        "tpm": 100000,
+        "rpm": 1000,
     },
-    "tpm": 100000,
-    "rpm": 1000,
 }, {
     "model_name": "gpt-3.5-turbo",
     "litellm_params": { # params for litellm completion/embedding call
         "model": "gpt-3.5-turbo",
         "api_key": os.getenv("OPENAI_API_KEY"),
+        "tpm": 100000,
+        "rpm": 1000,
     },
-    "tpm": 100000,
-    "rpm": 1000,
 }]
 
 router = Router(model_list=model_list,
                 redis_host=os.environ["REDIS_HOST"],
                 redis_password=os.environ["REDIS_PASSWORD"],
                 redis_port=os.environ["REDIS_PORT"],
                 routing_strategy="usage-based-routing-v2" # 👈 KEY CHANGE
-                enable_pre_call_check=True, # enables router rate limits for concurrent calls
+                enable_pre_call_checks=True, # enables router rate limits for concurrent calls
                 )
 
 response = await router.acompletion(model="gpt-3.5-turbo",
diff --git a/litellm/__init__.py b/litellm/__init__.py
index fb6234e04..f0b930dc1 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -495,6 +495,7 @@ openai_compatible_providers: List = [
     "friendliai",
     "azure_ai",
     "github",
+    "litellm_proxy",
 ]
 openai_text_completion_compatible_providers: List = (
     [  # providers that support `/v1/completions`
@@ -748,6 +749,7 @@ class LlmProviders(str, Enum):
     EMPOWER = "empower"
     GITHUB = "github"
     CUSTOM = "custom"
+    LITELLM_PROXY = "litellm_proxy"
 
 
 provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py
index eaa8f730d..309eea529 100644
--- a/litellm/litellm_core_utils/get_llm_provider_logic.py
+++ b/litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -243,6 +243,9 @@ def get_llm_provider(
         elif custom_llm_provider == "github":
             api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com"  # type: ignore
             dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")
+        elif custom_llm_provider == "litellm_proxy":
+            api_base = api_base or get_secret("LITELLM_PROXY_API_BASE") or "https://models.inference.ai.azure.com"  # type: ignore
+            dynamic_api_key = api_key or get_secret("LITELLM_PROXY_API_KEY")
         elif custom_llm_provider == "mistral":
             # mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
diff --git a/litellm/llms/OpenAI/chat/o1_transformation.py b/litellm/llms/OpenAI/chat/o1_transformation.py
index 5c4efbcc6..200097f67 100644
--- a/litellm/llms/OpenAI/chat/o1_transformation.py
+++ b/litellm/llms/OpenAI/chat/o1_transformation.py
@@ -65,6 +65,7 @@ class OpenAIO1Config(OpenAIGPTConfig):
             "top_logprobs",
             "response_format",
             "stop",
+            "stream_options",
         ]
 
         return [
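With `stream_options` added to the excluded list above, the o1 config should no longer advertise it as a supported param. A quick sanity-check sketch, assuming `get_supported_openai_params` dispatches o1 models to `OpenAIO1Config`:

```python
from litellm import get_supported_openai_params

supported = get_supported_openai_params(
    model="o1-preview", custom_llm_provider="openai"
)
print("stream_options" in supported)  # expected: False after this change

# with litellm.drop_params = True, a stray stream_options kwarg would be
# dropped instead of raising an unsupported-params error
```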
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py
index a82da7ad8..d11906b8c 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py
@@ -16,6 +16,7 @@ from litellm.types.llms.vertex_ai import (
 from litellm.utils import ModelResponse
 
 from ..common_utils import VertexAIError
+from ..vertex_llm_base import VertexBase
 from .transformation import (
     separate_cached_messages,
     transform_openai_messages_to_gemini_context_caching,
@@ -24,7 +25,7 @@ from .transformation import (
 
 local_cache_obj = Cache(type="local")  # only used for calling 'get_cache_key' function
 
-class ContextCachingEndpoints:
+class ContextCachingEndpoints(VertexBase):
     """
     Covers context caching endpoints for Vertex AI + Google AI Studio
 
@@ -34,7 +35,7 @@ class ContextCachingEndpoints:
     def __init__(self) -> None:
         pass
 
-    def _get_token_and_url(
+    def _get_token_and_url_context_caching(
         self,
         gemini_api_key: Optional[str],
         custom_llm_provider: Literal["gemini"],
@@ -57,18 +58,16 @@ class ContextCachingEndpoints:
         else:
             raise NotImplementedError
 
-        if (
-            api_base is not None
-        ):  # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
-            if custom_llm_provider == "gemini":
-                url = "{}/{}".format(api_base, endpoint)
-                auth_header = (
-                    gemini_api_key  # cloudflare expects api key as bearer token
-                )
-            else:
-                url = "{}:{}".format(api_base, endpoint)
-
-        return auth_header, url
+        return self._check_custom_proxy(
+            api_base=api_base,
+            custom_llm_provider=custom_llm_provider,
+            gemini_api_key=gemini_api_key,
+            endpoint=endpoint,
+            stream=None,
+            auth_header=auth_header,
+            url=url,
+        )
 
     def check_cache(
         self,
@@ -90,7 +89,7 @@ class ContextCachingEndpoints:
         - None
         """
 
-        _, url = self._get_token_and_url(
+        _, url = self._get_token_and_url_context_caching(
             gemini_api_key=api_key,
             custom_llm_provider="gemini",
             api_base=api_base,
@@ -125,11 +124,12 @@ class ContextCachingEndpoints:
 
         all_cached_items = CachedContentListAllResponseBody(**raw_response)
 
+        if "cachedContents" not in all_cached_items:
+            return None
+
         for cached_item in all_cached_items["cachedContents"]:
-            if (
-                cached_item.get("displayName") is not None
-                and cached_item["displayName"] == cache_key
-            ):
+            display_name = cached_item.get("displayName")
+            if display_name is not None and display_name == cache_key:
                 return cached_item.get("name")
 
         return None
@@ -154,7 +154,7 @@ class ContextCachingEndpoints:
         - None
         """
 
-        _, url = self._get_token_and_url(
+        _, url = self._get_token_and_url_context_caching(
             gemini_api_key=api_key,
             custom_llm_provider="gemini",
             api_base=api_base,
@@ -189,11 +189,12 @@ class ContextCachingEndpoints:
 
         all_cached_items = CachedContentListAllResponseBody(**raw_response)
 
+        if "cachedContents" not in all_cached_items:
+            return None
+
         for cached_item in all_cached_items["cachedContents"]:
-            if (
-                cached_item.get("displayName") is not None
-                and cached_item["displayName"] == cache_key
-            ):
+            display_name = cached_item.get("displayName")
+            if display_name is not None and display_name == cache_key:
                 return cached_item.get("name")
 
         return None
@@ -224,7 +225,7 @@ class ContextCachingEndpoints:
             return messages, cached_content
 
         ## AUTHORIZATION ##
-        token, url = self._get_token_and_url(
+        token, url = self._get_token_and_url_context_caching(
             gemini_api_key=api_key,
             custom_llm_provider="gemini",
             api_base=api_base,
@@ -329,7 +330,7 @@ class ContextCachingEndpoints:
             return messages, cached_content
 
         ## AUTHORIZATION ##
-        token, url = self._get_token_and_url(
+        token, url = self._get_token_and_url_context_caching(
             gemini_api_key=api_key,
             custom_llm_provider="gemini",
             api_base=api_base,
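For context, the `cachedContents:list` response from Google AI Studio omits the `cachedContents` key entirely when no caches exist yet, which is what the new early-return guard handles. A standalone sketch of the lookup logic (hypothetical helper, not part of the patch):

```python
from typing import Optional


def find_cached_content_name(raw_response: dict, cache_key: str) -> Optional[str]:
    # no caches created yet -> the key is simply absent from the response
    if "cachedContents" not in raw_response:
        return None
    for cached_item in raw_response["cachedContents"]:
        if cached_item.get("displayName") == cache_key:
            return cached_item.get("name")
    return None


assert find_cached_content_name({}, "my-key") is None
assert (
    find_cached_content_name(
        {"cachedContents": [{"displayName": "my-key", "name": "cachedContents/123"}]},
        "my-key",
    )
    == "cachedContents/123"
)
```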
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_llm_base.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_llm_base.py
index dbd19b8c3..740bdca5c 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_llm_base.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_llm_base.py
@@ -86,7 +86,7 @@ class VertexBase(BaseLLM):
                 )
 
                 if project_id is None:
-                    project_id = creds.project_id
+                    project_id = getattr(creds, "project_id", None)
             else:
                 creds, creds_project_id = google_auth.default(
                     quota_project_id=project_id,
@@ -95,7 +95,7 @@ class VertexBase(BaseLLM):
                 if project_id is None:
                     project_id = creds_project_id
 
-            creds.refresh(Request())
+            creds.refresh(Request())  # type: ignore
 
             if not project_id:
                 raise ValueError("Could not resolve project_id")
@@ -169,6 +169,39 @@ class VertexBase(BaseLLM):
                 return True
         return False
 
+    def _check_custom_proxy(
+        self,
+        api_base: Optional[str],
+        custom_llm_provider: str,
+        gemini_api_key: Optional[str],
+        endpoint: str,
+        stream: Optional[bool],
+        auth_header: Optional[str],
+        url: str,
+    ) -> Tuple[Optional[str], str]:
+        """
+        for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
+
+        ## Returns
+        - (auth_header, url) - Tuple[Optional[str], str]
+        """
+        if api_base:
+            if custom_llm_provider == "gemini":
+                url = "{}:{}".format(api_base, endpoint)
+                if gemini_api_key is None:
+                    raise ValueError(
+                        "Missing gemini_api_key, please set `GEMINI_API_KEY`"
+                    )
+                auth_header = (
+                    gemini_api_key  # cloudflare expects api key as bearer token
+                )
+            else:
+                url = "{}:{}".format(api_base, endpoint)
+
+        if stream is True:
+            url = url + "?alt=sse"
+        return auth_header, url
+
     def _get_token_and_url(
         self,
         model: str,
@@ -215,25 +248,15 @@ class VertexBase(BaseLLM):
             vertex_api_version=version,
         )
 
-        if (
-            api_base is not None
-        ):  # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
-            if custom_llm_provider == "gemini":
-                url = "{}:{}".format(api_base, endpoint)
-                if gemini_api_key is None:
-                    raise ValueError(
-                        "Missing gemini_api_key, please set `GEMINI_API_KEY`"
-                    )
-                auth_header = (
-                    gemini_api_key  # cloudflare expects api key as bearer token
-                )
-            else:
-                url = "{}:{}".format(api_base, endpoint)
-
-        if stream is True:
-            url = url + "?alt=sse"
-
-        return auth_header, url
+        return self._check_custom_proxy(
+            api_base=api_base,
+            auth_header=auth_header,
+            custom_llm_provider=custom_llm_provider,
+            gemini_api_key=gemini_api_key,
+            endpoint=endpoint,
+            stream=stream,
+            url=url,
+        )
 
     async def _ensure_access_token_async(
         self,
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index ee375a875..b174f1fc2 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1422,6 +1422,13 @@ class UserAPIKeyAuth(
         arbitrary_types_allowed = True
 
 
+class UserInfoResponse(LiteLLMBase):
+    user_id: Optional[str]
+    user_info: Optional[Union[dict, BaseModel]]
+    keys: List
+    teams: List
+
+
 class LiteLLM_Config(LiteLLMBase):
     param_name: str
     param_value: Dict
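A minimal sketch of what the new response model enforces on `/user/info` (field values are illustrative, not from the patch; importing from `litellm.proxy._types` assumes the proxy extras are installed):

```python
from litellm.proxy._types import UserInfoResponse

response = UserInfoResponse(
    user_id="user-1234",
    user_info={"user_role": "internal_user", "spend": 0.0},
    keys=[],   # keys owned by the user (master key filtered out when disable_master_key_return is set)
    teams=[],  # teams the user belongs to
)
print(response.model_dump())  # .dict() on pydantic v1
```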
diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py
index 771234557..af57d849f 100644
--- a/litellm/proxy/management_endpoints/internal_user_endpoints.py
+++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py
@@ -15,6 +15,7 @@ import copy
 import json
 import re
 import secrets
+import time
 import traceback
 import uuid
 from datetime import datetime, timedelta, timezone
@@ -274,10 +275,23 @@ async def ui_get_available_role(
     return _data_to_return
 
 
+def get_team_from_list(
+    team_list: Optional[List[LiteLLM_TeamTable]], team_id: str
+) -> Optional[LiteLLM_TeamTable]:
+    if team_list is None:
+        return None
+
+    for team in team_list:
+        if team.team_id == team_id:
+            return team
+    return None
+
+
 @router.get(
     "/user/info",
     tags=["Internal User management"],
     dependencies=[Depends(user_api_key_auth)],
+    response_model=UserInfoResponse,
 )
 @management_endpoint_wrapper
 async def user_info(
@@ -337,7 +351,7 @@ async def user_info(
         ## GET ALL TEAMS ##
         team_list = []
         team_id_list = []
-        # _DEPRECATED_ check if user in 'member' field
+        # get all teams user belongs to
         teams_1 = await prisma_client.get_data(
             user_id=user_id, table_name="team", query_type="find_all"
         )
@@ -414,13 +428,13 @@ async def user_info(
             if (
                 key.token == litellm_master_key_hash
                 and general_settings.get("disable_master_key_return", False)
-                == True  ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
+                is True  ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
             ):
                 continue
 
             try:
                 key = key.model_dump()  # noqa
-            except:
+            except Exception:
                 # if using pydantic v1
                 key = key.dict()
             if (
@@ -428,29 +442,29 @@ async def user_info(
                 and key["team_id"] is not None
                 and key["team_id"] != "litellm-dashboard"
             ):
-                team_info = await prisma_client.get_data(
-                    team_id=key["team_id"], table_name="team"
+                team_info = get_team_from_list(
+                    team_list=teams_1, team_id=key["team_id"]
                 )
-                team_alias = getattr(team_info, "team_alias", None)
-                key["team_alias"] = team_alias
+                if team_info is not None:
+                    team_alias = getattr(team_info, "team_alias", None)
+                    key["team_alias"] = team_alias
+                else:
+                    key["team_alias"] = None
             else:
                 key["team_alias"] = "None"
             returned_keys.append(key)
 
-        response_data = {
-            "user_id": user_id,
-            "user_info": user_info,
-            "keys": returned_keys,
-            "teams": team_list,
-        }
+        response_data = UserInfoResponse(
+            user_id=user_id, user_info=user_info, keys=returned_keys, teams=team_list
+        )
+
         return response_data
     except Exception as e:
-        verbose_proxy_logger.error(
+        verbose_proxy_logger.exception(
             "litellm.proxy.proxy_server.user_info(): Exception occured - {}".format(
                 str(e)
             )
         )
-        verbose_proxy_logger.debug(traceback.format_exc())
         if isinstance(e, HTTPException):
             raise ProxyException(
                 message=getattr(e, "detail", f"Authentication Error({str(e)})"),
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 61e6879fd..686793e7f 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -248,7 +248,7 @@ from litellm.secret_managers.aws_secret_manager import (
     load_aws_secret_manager,
 )
 from litellm.secret_managers.google_kms import load_google_kms
-from litellm.secret_managers.main import get_secret, str_to_bool
+from litellm.secret_managers.main import get_secret, get_secret_str, str_to_bool
 from litellm.types.llms.anthropic import (
     AnthropicMessagesRequest,
     AnthropicResponse,
@@ -2728,9 +2728,21 @@ async def startup_event():
 
     ### LOAD CONFIG ###
     worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG")  # type: ignore
+    env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
     verbose_proxy_logger.debug("worker_config: %s", worker_config)
     # check if it's a valid file path
-    if worker_config is not None:
+    if env_config_yaml is not None:
+        if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
+            config_file_path=env_config_yaml
+        ):
+            (
+                llm_router,
+                llm_model_list,
+                general_settings,
+            ) = await proxy_config.load_config(
+                router=llm_router, config_file_path=env_config_yaml
+            )
+    elif worker_config is not None:
         if (
             isinstance(worker_config, str)
             and os.path.isfile(worker_config)
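The startup change above means the proxy can pick up its config purely from the environment. A rough sketch of how that might be used (the path and port are assumptions, and this assumes the `litellm` proxy CLI is installed):

```python
import os
import subprocess

# point the proxy at a config.yaml without passing --config on the CLI;
# startup_event() checks CONFIG_FILE_PATH before falling back to WORKER_CONFIG
os.environ["CONFIG_FILE_PATH"] = "/app/config.yaml"

subprocess.run(["litellm", "--port", "4000"], check=True)
```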
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index d65954127..5dfdd7c4b 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -31,6 +31,7 @@ from litellm import (
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     _gemini_convert_messages_with_history,
 )
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
 from litellm.tests.test_streaming import streaming_format_tests
 
 litellm.num_retries = 3
@@ -2869,3 +2870,24 @@ def test_gemini_finetuned_endpoint(base_model, metadata):
     assert mock_client.call_args.kwargs["url"].endswith(
         "endpoints/4965075652664360960:generateContent"
     )
+
+
+@pytest.mark.parametrize("api_base", ["", None, "my-custom-proxy-base"])
+def test_custom_api_base(api_base):
+    stream = None
+    test_endpoint = "my-fake-endpoint"
+    vertex_base = VertexBase()
+    auth_header, url = vertex_base._check_custom_proxy(
+        api_base=api_base,
+        custom_llm_provider="gemini",
+        gemini_api_key="12324",
+        endpoint="",
+        stream=stream,
+        auth_header=None,
+        url="my-fake-endpoint",
+    )
+
+    if api_base:
+        assert url == api_base + ":"
+    else:
+        assert url == test_endpoint
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 13bc92156..12ac3e0b4 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1933,6 +1933,41 @@ async def test_openai_compatible_custom_api_base(provider):
     assert "hello" in mock_call.call_args.kwargs["extra_body"]
 
 
+@pytest.mark.asyncio
+async def test_litellm_gateway_from_sdk():
+    litellm.set_verbose = True
+    messages = [
+        {
+            "role": "user",
+            "content": "Hello world",
+        }
+    ]
+    from openai import OpenAI
+
+    openai_client = OpenAI(api_key="fake-key")
+
+    with patch.object(
+        openai_client.chat.completions, "create", new=MagicMock()
+    ) as mock_call:
+        try:
+            completion(
+                model="litellm_proxy/my-vllm-model",
+                messages=messages,
+                response_format={"type": "json_object"},
+                client=openai_client,
+                api_base="my-custom-api-base",
+                hello="world",
+            )
+        except Exception as e:
+            print(e)
+
+        mock_call.assert_called_once()
+
+        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
+
+        assert "hello" in mock_call.call_args.kwargs["extra_body"]
+
+
 # ################### Hugging Face Conversational models ########################
 # def hf_test_completion_conv():
 #     try:
diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py
index 0a9264c9e..25fa83f45 100644
--- a/litellm/tests/test_key_generate_prisma.py
+++ b/litellm/tests/test_key_generate_prisma.py
@@ -256,7 +256,7 @@ def test_generate_and_call_with_valid_key(prisma_client, api_route):
 
         # check /user/info to verify user_role was set correctly
         new_user_info = await user_info(user_id=user_id)
-        new_user_info = new_user_info["user_info"]
+        new_user_info = new_user_info.user_info
         print("new_user_info=", new_user_info)
         assert new_user_info.user_role == LitellmUserRoles.INTERNAL_USER
         assert new_user_info.user_id == user_id
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 3ec9a96ba..f6848d97e 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -2154,7 +2154,7 @@ def test_openai_chat_completion_complete_response_call():
 # test_openai_chat_completion_complete_response_call()
 @pytest.mark.parametrize(
     "model",
-    ["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307"],  #
+    ["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307", "o1-preview"],  #
 )
 @pytest.mark.parametrize(
     "sync",