Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
LiteLLM Minor Fixes & Improvements (09/20/2024) (#5807)
* fix(vertex_llm_base.py): Handle api_base = ""
  Fixes https://github.com/BerriAI/litellm/issues/5798
* fix(o1_transformation.py): handle stream_options not being supported
  https://github.com/BerriAI/litellm/issues/5803
* docs(routing.md): fix docs
  Closes https://github.com/BerriAI/litellm/issues/5808
* perf(internal_user_endpoints.py): reduce db calls for getting team_alias for a key
  Use the list gotten earlier in `/user/info` endpoint
  Reduces ui keys tab load time to 800ms (prev. 28s+)
* feat(proxy_server.py): support CONFIG_FILE_PATH as env var
  Closes https://github.com/BerriAI/litellm/issues/5744
* feat(get_llm_provider_logic.py): add `litellm_proxy/` as a known openai-compatible route (usage sketch below)
  simplifies calling litellm proxy
  Reduces confusion when calling models on litellm proxy from litellm sdk
* docs(litellm_proxy.md): cleanup docs
* fix(internal_user_endpoints.py): fix pydantic obj
* test(test_key_generate_prisma.py): fix test
This commit is contained in:
parent 0c488cf4ca
commit d6ca7fed18
14 changed files with 204 additions and 84 deletions
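For context, a minimal usage sketch of the new `litellm_proxy/` route, assembled from the doc changes in this commit; the proxy URL, API key, and model name are placeholders for your own deployment:

```python
import os
from litellm import completion

# placeholders - point these at your own LiteLLM proxy deployment
os.environ["LITELLM_PROXY_API_KEY"] = "sk-1234"
os.environ["LITELLM_PROXY_API_BASE"] = "http://localhost:4000"

messages = [{"content": "Hello, how are you?", "role": "user"}]

# the `litellm_proxy/` prefix routes the call through the proxy as an
# openai-compatible provider, instead of the generic `openai/` prefix
response = completion(model="litellm_proxy/your-model-name", messages=messages)
```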
@@ -9,13 +9,13 @@ import TabItem from '@theme/TabItem';
 
 :::
 
-**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**, you just need the `openai/` prefix before the model
+**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**, you just need the `litellm_proxy/` prefix before the model
 
 ## Required Variables
 
 ```python
-os.environ["OPENAI_API_KEY"] = "" # "sk-1234" your litellm proxy api key
-os.environ["OPENAI_API_BASE"] = "" # "http://localhost:4000" your litellm proxy api base
+os.environ["LITELLM_PROXY_API_KEY"] = "" # "sk-1234" your litellm proxy api key
+os.environ["LITELLM_PROXY_API_BASE"] = "" # "http://localhost:4000" your litellm proxy api base
 ```
 
 
@@ -25,18 +25,18 @@ import os
 import litellm
 from litellm import completion
 
-os.environ["OPENAI_API_KEY"] = ""
+os.environ["LITELLM_PROXY_API_KEY"] = ""
 
 # set custom api base to your proxy
 # either set .env or litellm.api_base
-# os.environ["OPENAI_API_BASE"] = ""
+# os.environ["LITELLM_PROXY_API_BASE"] = ""
 litellm.api_base = "your-openai-proxy-url"
 
 
 messages = [{ "content": "Hello, how are you?","role": "user"}]
 
-# openai call
-response = completion(model="openai/your-model-name", messages)
+# litellm proxy call
+response = completion(model="litellm_proxy/your-model-name", messages)
 ```
 
 ## Usage - passing `api_base`, `api_key` per request
@@ -48,13 +48,13 @@ import os
 import litellm
 from litellm import completion
 
-os.environ["OPENAI_API_KEY"] = ""
+os.environ["LITELLM_PROXY_API_KEY"] = ""
 
 messages = [{ "content": "Hello, how are you?","role": "user"}]
 
-# openai call
+# litellm proxy call
 response = completion(
-model="openai/your-model-name",
+model="litellm_proxy/your-model-name",
 messages,
 api_base = "your-litellm-proxy-url",
 api_key = "your-litellm-proxy-api-key"
@@ -67,13 +67,13 @@ import os
 import litellm
 from litellm import completion
 
-os.environ["OPENAI_API_KEY"] = ""
+os.environ["LITELLM_PROXY_API_KEY"] = ""
 
 messages = [{ "content": "Hello, how are you?","role": "user"}]
 
 # openai call
 response = completion(
-model="openai/your-model-name",
+model="litellm_proxy/your-model-name",
 messages,
 api_base = "your-litellm-proxy-url",
 stream=True
@@ -183,9 +183,9 @@ model_list = [{ # list of model deployments
 "api_key": os.getenv("AZURE_API_KEY"),
 "api_version": os.getenv("AZURE_API_VERSION"),
 "api_base": os.getenv("AZURE_API_BASE")
-},
 "tpm": 100000,
 "rpm": 10000,
+},
 }, {
 "model_name": "gpt-3.5-turbo",
 "litellm_params": { # params for litellm completion/embedding call
@@ -193,24 +193,24 @@ model_list = [{ # list of model deployments
 "api_key": os.getenv("AZURE_API_KEY"),
 "api_version": os.getenv("AZURE_API_VERSION"),
 "api_base": os.getenv("AZURE_API_BASE")
-},
 "tpm": 100000,
 "rpm": 1000,
+},
 }, {
 "model_name": "gpt-3.5-turbo",
 "litellm_params": { # params for litellm completion/embedding call
 "model": "gpt-3.5-turbo",
 "api_key": os.getenv("OPENAI_API_KEY"),
-},
 "tpm": 100000,
 "rpm": 1000,
+},
 }]
 router = Router(model_list=model_list,
 redis_host=os.environ["REDIS_HOST"],
 redis_password=os.environ["REDIS_PASSWORD"],
 redis_port=os.environ["REDIS_PORT"],
 routing_strategy="usage-based-routing-v2" # 👈 KEY CHANGE
-enable_pre_call_check=True, # enables router rate limits for concurrent calls
+enable_pre_call_checks=True, # enables router rate limits for concurrent calls
 )
 
 response = await router.acompletion(model="gpt-3.5-turbo",
@@ -495,6 +495,7 @@ openai_compatible_providers: List = [
 "friendliai",
 "azure_ai",
 "github",
+"litellm_proxy",
 ]
 openai_text_completion_compatible_providers: List = (
 [ # providers that support `/v1/completions`
@@ -748,6 +749,7 @@ class LlmProviders(str, Enum):
 EMPOWER = "empower"
 GITHUB = "github"
 CUSTOM = "custom"
+LITELLM_PROXY = "litellm_proxy"
 
 
 provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
@@ -243,6 +243,9 @@ def get_llm_provider(
 elif custom_llm_provider == "github":
 api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
 dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")
+elif custom_llm_provider == "litellm_proxy":
+api_base = api_base or get_secret("LITELLM_PROXY_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
+dynamic_api_key = api_key or get_secret("LITELLM_PROXY_API_KEY")
 
 elif custom_llm_provider == "mistral":
 # mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
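A hedged sketch of what the new `litellm_proxy` branch above does at resolution time, assuming `get_llm_provider` keeps its usual `(model, provider, api_key, api_base)` return order; the env values and model name are placeholders:

```python
import os
from litellm import get_llm_provider

# placeholders - the new branch falls back to these env vars when no
# explicit api_base / api_key is passed in
os.environ["LITELLM_PROXY_API_BASE"] = "http://localhost:4000"
os.environ["LITELLM_PROXY_API_KEY"] = "sk-1234"

model, provider, api_key, api_base = get_llm_provider(model="litellm_proxy/gpt-4o")
# provider == "litellm_proxy"; api_key / api_base are pulled from the env vars above
print(provider, api_base)
```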
@@ -65,6 +65,7 @@ class OpenAIO1Config(OpenAIGPTConfig):
 "top_logprobs",
 "response_format",
 "stop",
+"stream_options",
 ]
 
 return [
@@ -16,6 +16,7 @@ from litellm.types.llms.vertex_ai import (
 from litellm.utils import ModelResponse
 
 from ..common_utils import VertexAIError
+from ..vertex_llm_base import VertexBase
 from .transformation import (
 separate_cached_messages,
 transform_openai_messages_to_gemini_context_caching,
@@ -24,7 +25,7 @@ from .transformation import (
 local_cache_obj = Cache(type="local") # only used for calling 'get_cache_key' function
 
 
-class ContextCachingEndpoints:
+class ContextCachingEndpoints(VertexBase):
 """
 Covers context caching endpoints for Vertex AI + Google AI Studio
 
@@ -34,7 +35,7 @@ class ContextCachingEndpoints:
 def __init__(self) -> None:
 pass
 
-def _get_token_and_url(
+def _get_token_and_url_context_caching(
 self,
 gemini_api_key: Optional[str],
 custom_llm_provider: Literal["gemini"],
@@ -57,18 +58,16 @@ class ContextCachingEndpoints:
 
 else:
 raise NotImplementedError
-if (
-api_base is not None
-): # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
-if custom_llm_provider == "gemini":
-url = "{}/{}".format(api_base, endpoint)
-auth_header = (
-gemini_api_key # cloudflare expects api key as bearer token
-)
-else:
-url = "{}:{}".format(api_base, endpoint)
 
-return auth_header, url
+return self._check_custom_proxy(
+api_base=api_base,
+custom_llm_provider=custom_llm_provider,
+gemini_api_key=gemini_api_key,
+endpoint=endpoint,
+stream=None,
+auth_header=auth_header,
+url=url,
+)
 
 def check_cache(
 self,
@@ -90,7 +89,7 @@ class ContextCachingEndpoints:
 - None
 """
 
-_, url = self._get_token_and_url(
+_, url = self._get_token_and_url_context_caching(
 gemini_api_key=api_key,
 custom_llm_provider="gemini",
 api_base=api_base,
@@ -125,11 +124,12 @@ class ContextCachingEndpoints:
 
 all_cached_items = CachedContentListAllResponseBody(**raw_response)
 
+if "cachedContents" not in all_cached_items:
+return None
+
 for cached_item in all_cached_items["cachedContents"]:
-if (
-cached_item.get("displayName") is not None
-and cached_item["displayName"] == cache_key
-):
+display_name = cached_item.get("displayName")
+if display_name is not None and display_name == cache_key:
 return cached_item.get("name")
 
 return None
@@ -154,7 +154,7 @@ class ContextCachingEndpoints:
 - None
 """
 
-_, url = self._get_token_and_url(
+_, url = self._get_token_and_url_context_caching(
 gemini_api_key=api_key,
 custom_llm_provider="gemini",
 api_base=api_base,
@@ -189,11 +189,12 @@ class ContextCachingEndpoints:
 
 all_cached_items = CachedContentListAllResponseBody(**raw_response)
 
+if "cachedContents" not in all_cached_items:
+return None
+
 for cached_item in all_cached_items["cachedContents"]:
-if (
-cached_item.get("displayName") is not None
-and cached_item["displayName"] == cache_key
-):
+display_name = cached_item.get("displayName")
+if display_name is not None and display_name == cache_key:
 return cached_item.get("name")
 
 return None
@@ -224,7 +225,7 @@ class ContextCachingEndpoints:
 return messages, cached_content
 
 ## AUTHORIZATION ##
-token, url = self._get_token_and_url(
+token, url = self._get_token_and_url_context_caching(
 gemini_api_key=api_key,
 custom_llm_provider="gemini",
 api_base=api_base,
@@ -329,7 +330,7 @@ class ContextCachingEndpoints:
 return messages, cached_content
 
 ## AUTHORIZATION ##
-token, url = self._get_token_and_url(
+token, url = self._get_token_and_url_context_caching(
 gemini_api_key=api_key,
 custom_llm_provider="gemini",
 api_base=api_base,
@@ -86,7 +86,7 @@ class VertexBase(BaseLLM):
 )
 
 if project_id is None:
-project_id = creds.project_id
+project_id = getattr(creds, "project_id", None)
 else:
 creds, creds_project_id = google_auth.default(
 quota_project_id=project_id,
@@ -95,7 +95,7 @@ class VertexBase(BaseLLM):
 if project_id is None:
 project_id = creds_project_id
 
-creds.refresh(Request())
+creds.refresh(Request()) # type: ignore
 
 if not project_id:
 raise ValueError("Could not resolve project_id")
@@ -169,6 +169,39 @@ class VertexBase(BaseLLM):
 return True
 return False
 
+def _check_custom_proxy(
+self,
+api_base: Optional[str],
+custom_llm_provider: str,
+gemini_api_key: Optional[str],
+endpoint: str,
+stream: Optional[bool],
+auth_header: Optional[str],
+url: str,
+) -> Tuple[Optional[str], str]:
+"""
+for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
+
+## Returns
+- (auth_header, url) - Tuple[Optional[str], str]
+"""
+if api_base:
+if custom_llm_provider == "gemini":
+url = "{}:{}".format(api_base, endpoint)
+if gemini_api_key is None:
+raise ValueError(
+"Missing gemini_api_key, please set `GEMINI_API_KEY`"
+)
+auth_header = (
+gemini_api_key # cloudflare expects api key as bearer token
+)
+else:
+url = "{}:{}".format(api_base, endpoint)
+
+if stream is True:
+url = url + "?alt=sse"
+return auth_header, url
+
 def _get_token_and_url(
 self,
 model: str,
@@ -215,25 +248,15 @@ class VertexBase(BaseLLM):
 vertex_api_version=version,
 )
 
-if (
-api_base is not None
-): # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
-if custom_llm_provider == "gemini":
-url = "{}:{}".format(api_base, endpoint)
-if gemini_api_key is None:
-raise ValueError(
-"Missing gemini_api_key, please set `GEMINI_API_KEY`"
+return self._check_custom_proxy(
+api_base=api_base,
+auth_header=auth_header,
+custom_llm_provider=custom_llm_provider,
+gemini_api_key=gemini_api_key,
+endpoint=endpoint,
+stream=stream,
+url=url,
 )
-auth_header = (
-gemini_api_key # cloudflare expects api key as bearer token
-)
-else:
-url = "{}:{}".format(api_base, endpoint)
-
-if stream is True:
-url = url + "?alt=sse"
-
-return auth_header, url
 
 async def _ensure_access_token_async(
 self,
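The `if api_base:` check in `_check_custom_proxy` above is what handles the `api_base = ""` case from the commit message: an empty string is falsy, so the default `url` passed in is kept instead of building a request URL against an empty base. A tiny illustration of the truthiness behavior it relies on (not the library's code; the endpoint strings are placeholders):

```python
# Illustration only - `if api_base:` vs `if api_base is not None:`
def pick_url(api_base, default_url="https://default-endpoint:generateContent"):
    if api_base:  # "" and None both fall through to the default
        return "{}:generateContent".format(api_base)
    return default_url

assert pick_url("") == "https://default-endpoint:generateContent"
assert pick_url(None) == "https://default-endpoint:generateContent"
assert pick_url("https://my-proxy") == "https://my-proxy:generateContent"
```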
@@ -1422,6 +1422,13 @@ class UserAPIKeyAuth(
 arbitrary_types_allowed = True
 
 
+class UserInfoResponse(LiteLLMBase):
+user_id: Optional[str]
+user_info: Optional[Union[dict, BaseModel]]
+keys: List
+teams: List
+
+
 class LiteLLM_Config(LiteLLMBase):
 param_name: str
 param_value: Dict
@@ -15,6 +15,7 @@ import copy
 import json
 import re
 import secrets
+import time
 import traceback
 import uuid
 from datetime import datetime, timedelta, timezone
@@ -274,10 +275,23 @@ async def ui_get_available_role(
 return _data_to_return
 
 
+def get_team_from_list(
+team_list: Optional[List[LiteLLM_TeamTable]], team_id: str
+) -> Optional[LiteLLM_TeamTable]:
+if team_list is None:
+return None
+
+for team in team_list:
+if team.team_id == team_id:
+return team
+return None
+
+
 @router.get(
 "/user/info",
 tags=["Internal User management"],
 dependencies=[Depends(user_api_key_auth)],
+response_model=UserInfoResponse,
 )
 @management_endpoint_wrapper
 async def user_info(
@@ -337,7 +351,7 @@ async def user_info(
 ## GET ALL TEAMS ##
 team_list = []
 team_id_list = []
-# _DEPRECATED_ check if user in 'member' field
+# get all teams user belongs to
 teams_1 = await prisma_client.get_data(
 user_id=user_id, table_name="team", query_type="find_all"
 )
@@ -414,13 +428,13 @@ async def user_info(
 if (
 key.token == litellm_master_key_hash
 and general_settings.get("disable_master_key_return", False)
-== True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
+is True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
 ):
 continue
 
 try:
 key = key.model_dump() # noqa
-except:
+except Exception:
 # if using pydantic v1
 key = key.dict()
 if (
@@ -428,29 +442,29 @@ async def user_info(
 and key["team_id"] is not None
 and key["team_id"] != "litellm-dashboard"
 ):
-team_info = await prisma_client.get_data(
-team_id=key["team_id"], table_name="team"
+team_info = get_team_from_list(
+team_list=teams_1, team_id=key["team_id"]
 )
+if team_info is not None:
 team_alias = getattr(team_info, "team_alias", None)
 key["team_alias"] = team_alias
+else:
+key["team_alias"] = None
 else:
 key["team_alias"] = "None"
 returned_keys.append(key)
 
-response_data = {
-"user_id": user_id,
-"user_info": user_info,
-"keys": returned_keys,
-"teams": team_list,
-}
+response_data = UserInfoResponse(
+user_id=user_id, user_info=user_info, keys=returned_keys, teams=team_list
+)
 return response_data
 except Exception as e:
-verbose_proxy_logger.error(
+verbose_proxy_logger.exception(
 "litellm.proxy.proxy_server.user_info(): Exception occured - {}".format(
 str(e)
 )
 )
-verbose_proxy_logger.debug(traceback.format_exc())
 if isinstance(e, HTTPException):
 raise ProxyException(
 message=getattr(e, "detail", f"Authentication Error({str(e)})"),
@@ -248,7 +248,7 @@ from litellm.secret_managers.aws_secret_manager import (
 load_aws_secret_manager,
 )
 from litellm.secret_managers.google_kms import load_google_kms
-from litellm.secret_managers.main import get_secret, str_to_bool
+from litellm.secret_managers.main import get_secret, get_secret_str, str_to_bool
 from litellm.types.llms.anthropic import (
 AnthropicMessagesRequest,
 AnthropicResponse,
@@ -2728,9 +2728,21 @@ async def startup_event():
 
 ### LOAD CONFIG ###
 worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
+env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
 verbose_proxy_logger.debug("worker_config: %s", worker_config)
 # check if it's a valid file path
-if worker_config is not None:
+if env_config_yaml is not None:
+if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
+config_file_path=env_config_yaml
+):
+(
+llm_router,
+llm_model_list,
+general_settings,
+) = await proxy_config.load_config(
+router=llm_router, config_file_path=env_config_yaml
+)
+elif worker_config is not None:
 if (
 isinstance(worker_config, str)
 and os.path.isfile(worker_config)
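A small sketch of the new `CONFIG_FILE_PATH` behavior above: if the env var points at an existing YAML config, the proxy loads it during `startup_event()` and only falls back to the existing `WORKER_CONFIG` handling otherwise. The path is a placeholder:

```python
import os

# placeholder path - CONFIG_FILE_PATH is read via get_secret_str() at startup
os.environ["CONFIG_FILE_PATH"] = "/etc/litellm/proxy_config.yaml"

# when the file exists and passes proxy_config.is_yaml(), startup calls
# proxy_config.load_config(..., config_file_path=env_config_yaml) with it;
# otherwise the WORKER_CONFIG branch shown above runs as before
```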
@@ -31,6 +31,7 @@ from litellm import (
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
 _gemini_convert_messages_with_history,
 )
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
 from litellm.tests.test_streaming import streaming_format_tests
 
 litellm.num_retries = 3
@@ -2869,3 +2870,24 @@ def test_gemini_finetuned_endpoint(base_model, metadata):
 assert mock_client.call_args.kwargs["url"].endswith(
 "endpoints/4965075652664360960:generateContent"
 )
+
+
+@pytest.mark.parametrize("api_base", ["", None, "my-custom-proxy-base"])
+def test_custom_api_base(api_base):
+stream = None
+test_endpoint = "my-fake-endpoint"
+vertex_base = VertexBase()
+auth_header, url = vertex_base._check_custom_proxy(
+api_base=api_base,
+custom_llm_provider="gemini",
+gemini_api_key="12324",
+endpoint="",
+stream=stream,
+auth_header=None,
+url="my-fake-endpoint",
+)
+
+if api_base:
+assert url == api_base + ":"
+else:
+assert url == test_endpoint
@@ -1933,6 +1933,41 @@ async def test_openai_compatible_custom_api_base(provider):
 assert "hello" in mock_call.call_args.kwargs["extra_body"]
 
 
+@pytest.mark.asyncio
+async def test_litellm_gateway_from_sdk():
+litellm.set_verbose = True
+messages = [
+{
+"role": "user",
+"content": "Hello world",
+}
+]
+from openai import OpenAI
+
+openai_client = OpenAI(api_key="fake-key")
+
+with patch.object(
+openai_client.chat.completions, "create", new=MagicMock()
+) as mock_call:
+try:
+completion(
+model="litellm_proxy/my-vllm-model",
+messages=messages,
+response_format={"type": "json_object"},
+client=openai_client,
+api_base="my-custom-api-base",
+hello="world",
+)
+except Exception as e:
+print(e)
+
+mock_call.assert_called_once()
+
+print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
+
+assert "hello" in mock_call.call_args.kwargs["extra_body"]
+
+
 # ################### Hugging Face Conversational models ########################
 # def hf_test_completion_conv():
 # try:
@@ -256,7 +256,7 @@ def test_generate_and_call_with_valid_key(prisma_client, api_route):
 
 # check /user/info to verify user_role was set correctly
 new_user_info = await user_info(user_id=user_id)
-new_user_info = new_user_info["user_info"]
+new_user_info = new_user_info.user_info
 print("new_user_info=", new_user_info)
 assert new_user_info.user_role == LitellmUserRoles.INTERNAL_USER
 assert new_user_info.user_id == user_id
@@ -2154,7 +2154,7 @@ def test_openai_chat_completion_complete_response_call():
 # test_openai_chat_completion_complete_response_call()
 @pytest.mark.parametrize(
 "model",
-["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307"], #
+["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307", "o1-preview"], #
 )
 @pytest.mark.parametrize(
 "sync",