mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(proxy_server.py): fix get model info when litellm_model_id is set + move model analytics to free (#7886)
* fix(proxy_server.py): fix get model info when litellm_model_id is set. Fixes https://github.com/BerriAI/litellm/issues/7873
* test(test_models.py): add test to ensure get model info on a specific deployment returns the same value as the all-model info endpoint. Fixes https://github.com/BerriAI/litellm/issues/7873
* fix(usage.tsx): make model analytics free. Fixes @iqballx's feedback
* fix(invoke_handler.py): fix bedrock error chunk parsing; return the correct bedrock status code and error message if a chunk in the stream fails. Improves bedrock stream error handling
* fix(proxy_server.py): fix linting errors
* test(test_auth_checks.py): remove redundant test
* fix(proxy_server.py): fix linting errors
* test: fix flaky test
* test: fix test
parent 0295f494b6
commit c8aa876785
8 changed files with 146 additions and 131 deletions
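The main proxy change makes `/model/info` return the enriched info for a single deployment when a `litellm_model_id` query parameter is passed, matching what the unfiltered endpoint returns. A rough sketch of querying the endpoint both ways against a locally running proxy; the base URL and key are placeholders mirroring the test helper further below, not values from this commit:

import requests

PROXY_BASE = "http://0.0.0.0:4000"  # placeholder: local litellm proxy
HEADERS = {
    "Authorization": "Bearer sk-1234",  # placeholder proxy key
    "Content-Type": "application/json",
}

# all deployments visible to this key
all_info = requests.get(f"{PROXY_BASE}/model/info", headers=HEADERS).json()["data"]

# the same deployment fetched by its litellm model id
model_id = all_info[0]["model_info"]["id"]
one_info = requests.get(
    f"{PROXY_BASE}/model/info",
    params={"litellm_model_id": model_id},
    headers=HEADERS,
).json()["data"][0]

# after this commit the per-deployment entry should match the corresponding
# entry from the unfiltered call (cost and max_tokens fields included)
matching = next(m for m in all_info if m["model_info"]["id"] == model_id)
assert one_info == matching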
litellm/llms/bedrock/chat/invoke_handler.py

@@ -894,7 +894,7 @@ class BedrockLLM(BaseAWSLLM):

         if response.status_code != 200:
             raise BedrockError(
-                status_code=response.status_code, message=response.read()
+                status_code=response.status_code, message=str(response.read())
             )

         decoder = AWSEventStreamDecoder(model=model)
@@ -1247,7 +1247,23 @@ class AWSEventStreamDecoder:
         parsed_response = self.parser.parse(response_dict, get_response_stream_shape())

         if response_dict["status_code"] != 200:
-            raise ValueError(f"Bad response code, expected 200: {response_dict}")
+            decoded_body = response_dict["body"].decode()
+            if isinstance(decoded_body, dict):
+                error_message = decoded_body.get("message")
+            elif isinstance(decoded_body, str):
+                error_message = decoded_body
+            else:
+                error_message = ""
+            exception_status = response_dict["headers"].get(":exception-type")
+            error_message = exception_status + " " + error_message
+            raise BedrockError(
+                status_code=response_dict["status_code"],
+                message=(
+                    json.dumps(error_message)
+                    if isinstance(error_message, dict)
+                    else error_message
+                ),
+            )
         if "chunk" in parsed_response:
             chunk = parsed_response.get("chunk")
             if not chunk:
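The added branch above builds the error text from the failed stream event itself: the body bytes are decoded and prefixed with the ":exception-type" header, and a BedrockError with the real status code is raised instead of the old generic ValueError. A minimal, self-contained walk-through of that message assembly; the sample payload mirrors the mocked event in the new test added further below and is illustrative, not a captured Bedrock response:

# illustrative event payload (same shape as the mocked event in the new test)
response_dict = {
    "status_code": 400,
    "headers": {":exception-type": "serviceUnavailableException"},
    "body": b'{"message":"Bedrock is unable to process your request."}',
}

# same assembly as the added branch: .decode() yields a str, so the str path runs
decoded_body = response_dict["body"].decode()
if isinstance(decoded_body, dict):
    error_message = decoded_body.get("message")
elif isinstance(decoded_body, str):
    error_message = decoded_body
else:
    error_message = ""
exception_status = response_dict["headers"].get(":exception-type")
error_message = exception_status + " " + error_message

print(error_message)
# serviceUnavailableException {"message":"Bedrock is unable to process your request."}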
@@ -432,6 +432,7 @@ class Huggingface(BaseLLM):
         embed_url: str,
     ) -> dict:
         data: Dict = {}
+
         ## TRANSFORMATION ##
         if "sentence-transformers" in model:
             if len(input) == 0:
@@ -6,8 +6,7 @@ model_list:
       api_base: https://exampleopenaiendpoint-production.up.railway.app
   - model_name: openai-o1
     litellm_params:
-      model: openai/random_sleep
-      api_key: sk-1234
+      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
       api_base: http://0.0.0.0:8090
       timeout: 2
       num_retries: 0
litellm/proxy/proxy_server.py

@@ -274,6 +274,7 @@ from litellm.types.llms.anthropic import (
     AnthropicResponseUsageBlock,
 )
 from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.router import DeploymentTypedDict
 from litellm.types.router import ModelInfo as RouterModelInfo
 from litellm.types.router import RouterGeneralSettings, updateDeployment
 from litellm.types.utils import CustomHuggingfaceTokenizer
@@ -6510,6 +6511,47 @@ async def model_metrics_exceptions(
     return {"data": response, "exception_types": list(exception_types)}


+def _get_proxy_model_info(model: dict) -> dict:
+    # provided model_info in config.yaml
+    model_info = model.get("model_info", {})
+
+    # read litellm model_prices_and_context_window.json to get the following:
+    # input_cost_per_token, output_cost_per_token, max_tokens
+    litellm_model_info = get_litellm_model_info(model=model)
+
+    # 2nd pass on the model, try seeing if we can find model in litellm model_cost map
+    if litellm_model_info == {}:
+        # use litellm_param model_name to get model_info
+        litellm_params = model.get("litellm_params", {})
+        litellm_model = litellm_params.get("model", None)
+        try:
+            litellm_model_info = litellm.get_model_info(model=litellm_model)
+        except Exception:
+            litellm_model_info = {}
+    # 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
+    if litellm_model_info == {}:
+        # use litellm_param model_name to get model_info
+        litellm_params = model.get("litellm_params", {})
+        litellm_model = litellm_params.get("model", None)
+        split_model = litellm_model.split("/")
+        if len(split_model) > 0:
+            litellm_model = split_model[-1]
+        try:
+            litellm_model_info = litellm.get_model_info(
+                model=litellm_model, custom_llm_provider=split_model[0]
+            )
+        except Exception:
+            litellm_model_info = {}
+    for k, v in litellm_model_info.items():
+        if k not in model_info:
+            model_info[k] = v
+    model["model_info"] = model_info
+    # don't return the llm credentials
+    model = remove_sensitive_info_from_deployment(deployment_dict=model)
+
+    return model
+
+
 @router.get(
     "/model/info",
     tags=["model management"],
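The helper's second and third passes fall back to litellm's model cost map: first with the full litellm_params model string, then with the provider prefix split off and passed as custom_llm_provider. A rough, self-contained sketch of that lookup order (the example model name is arbitrary, not taken from this commit):

import litellm

def lookup_cost_map(litellm_model: str) -> dict:
    # 1st try: the full model string, e.g. "openai/gpt-4o"
    try:
        return litellm.get_model_info(model=litellm_model)
    except Exception:
        pass
    # 2nd try: strip the provider prefix and pass it separately,
    # e.g. model="gpt-4o", custom_llm_provider="openai"
    split_model = litellm_model.split("/")
    try:
        return litellm.get_model_info(
            model=split_model[-1], custom_llm_provider=split_model[0]
        )
    except Exception:
        return {}

info = lookup_cost_map("openai/gpt-4o")
print(info.get("max_tokens"), info.get("input_cost_per_token"))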
@@ -6598,16 +6640,15 @@ async def model_info_v1(  # noqa: PLR0915
         deployment_info = llm_router.get_deployment(model_id=litellm_model_id)
         if deployment_info is None:
             raise HTTPException(
-                status_code=404,
+                status_code=400,
                 detail={
                     "error": f"Model id = {litellm_model_id} not found on litellm proxy"
                 },
             )
-        _deployment_info_dict = deployment_info.model_dump()
-        _deployment_info_dict = remove_sensitive_info_from_deployment(
-            deployment_dict=_deployment_info_dict
+        _deployment_info_dict = _get_proxy_model_info(
+            model=deployment_info.model_dump(exclude_none=True)
         )
-        return {"data": _deployment_info_dict}
+        return {"data": [_deployment_info_dict]}

     all_models: List[dict] = []
     model_access_groups: Dict[str, List[str]] = defaultdict(list)
@@ -6647,42 +6688,7 @@ async def model_info_v1(  # noqa: PLR0915
         all_models = []

     for model in all_models:
-        # provided model_info in config.yaml
-        model_info = model.get("model_info", {})
-
-        # read litellm model_prices_and_context_window.json to get the following:
-        # input_cost_per_token, output_cost_per_token, max_tokens
-        litellm_model_info = get_litellm_model_info(model=model)
-
-        # 2nd pass on the model, try seeing if we can find model in litellm model_cost map
-        if litellm_model_info == {}:
-            # use litellm_param model_name to get model_info
-            litellm_params = model.get("litellm_params", {})
-            litellm_model = litellm_params.get("model", None)
-            try:
-                litellm_model_info = litellm.get_model_info(model=litellm_model)
-            except Exception:
-                litellm_model_info = {}
-        # 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
-        if litellm_model_info == {}:
-            # use litellm_param model_name to get model_info
-            litellm_params = model.get("litellm_params", {})
-            litellm_model = litellm_params.get("model", None)
-            split_model = litellm_model.split("/")
-            if len(split_model) > 0:
-                litellm_model = split_model[-1]
-            try:
-                litellm_model_info = litellm.get_model_info(
-                    model=litellm_model, custom_llm_provider=split_model[0]
-                )
-            except Exception:
-                litellm_model_info = {}
-        for k, v in litellm_model_info.items():
-            if k not in model_info:
-                model_info[k] = v
-        model["model_info"] = model_info
-        # don't return the llm credentials
-        model = remove_sensitive_info_from_deployment(deployment_dict=model)
+        model = _get_proxy_model_info(model=model)

     verbose_proxy_logger.debug("all_models: %s", all_models)
     return {"data": all_models}
@@ -2429,3 +2429,33 @@ async def test_bedrock_image_url_sync_client():
     except Exception as e:
         print(e)
     mock_post.assert_called_once()
+
+
+def test_bedrock_error_handling_streaming():
+    from litellm.llms.bedrock.chat.invoke_handler import (
+        AWSEventStreamDecoder,
+        BedrockError,
+    )
+    from unittest.mock import patch, Mock
+
+    event = Mock()
+    event.to_response_dict = Mock(
+        return_value={
+            "status_code": 400,
+            "headers": {
+                ":exception-type": "serviceUnavailableException",
+                ":content-type": "application/json",
+                ":message-type": "exception",
+            },
+            "body": b'{"message":"Bedrock is unable to process your request."}',
+        }
+    )
+
+    decoder = AWSEventStreamDecoder(
+        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
+    )
+    with pytest.raises(Exception) as e:
+        decoder._parse_message_from_event(event)
+    assert isinstance(e.value, BedrockError)
+    assert "Bedrock is unable to process your request." in e.value.message
+    assert e.value.status_code == 400
@@ -642,19 +642,27 @@ def tgi_mock_post(*args, **kwargs):
     from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler


-@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_hf_embedding_sentence_sim(sync_mode):
+@patch("litellm.llms.huggingface.chat.handler.async_get_hf_task_embedding_for_model")
+@patch("litellm.llms.huggingface.chat.handler.get_hf_task_embedding_for_model")
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_hf_embedding_sentence_sim(
+    mock_async_get_hf_task_embedding_for_model,
+    mock_get_hf_task_embedding_for_model,
+    sync_mode,
+):
     try:
         # huggingface/microsoft/codebert-base
         # huggingface/facebook/bart-large
+        mock_get_hf_task_embedding_for_model.return_value = "sentence-similarity"
+        mock_async_get_hf_task_embedding_for_model.return_value = "sentence-similarity"
         if sync_mode is True:
             client = HTTPHandler(concurrent_limit=1)
         else:
             client = AsyncHTTPHandler(concurrent_limit=1)
         with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
             data = {
-                "model": "huggingface/TaylorAI/bge-micro-v2",
+                "model": "huggingface/sentence-transformers/TaylorAI/bge-micro-v2",
                 "input": ["good morning from litellm", "this is another item"],
                 "client": client,
             }
test_models.py

@@ -88,11 +88,14 @@ async def add_models(session, model_id="123", model_name="azure-gpt-3.5"):
     return response_json


-async def get_model_info(session, key):
+async def get_model_info(session, key, litellm_model_id=None):
     """
     Make sure only models user has access to are returned
     """
-    url = "http://0.0.0.0:4000/model/info"
+    if litellm_model_id:
+        url = f"http://0.0.0.0:4000/model/info?litellm_model_id={litellm_model_id}"
+    else:
+        url = "http://0.0.0.0:4000/model/info"
     headers = {
         "Authorization": f"Bearer {key}",
         "Content-Type": "application/json",
@@ -148,6 +151,35 @@ async def test_get_models():
         assert m == "gpt-4"


+@pytest.mark.asyncio
+async def test_get_specific_model():
+    """
+    Return specific model info
+
+    Ensure value of model_info is same as on `/model/info` (no id set)
+    """
+    async with aiohttp.ClientSession() as session:
+        key_gen = await generate_key(session=session, models=["gpt-4"])
+        key = key_gen["key"]
+        response = await get_model_info(session=session, key=key)
+        models = [m["model_name"] for m in response["data"]]
+        model_specific_info = None
+        for idx, m in enumerate(models):
+            assert m == "gpt-4"
+            litellm_model_id = response["data"][idx]["model_info"]["id"]
+            model_specific_info = response["data"][idx]
+            assert litellm_model_id is not None
+        response = await get_model_info(
+            session=session, key=key, litellm_model_id=litellm_model_id
+        )
+        assert response["data"][0]["model_info"]["id"] == litellm_model_id
+        assert (
+            response["data"][0] == model_specific_info
+        ), "Model info is not the same. Got={}, Expected={}".format(
+            response["data"][0], model_specific_info
+        )
+
+
 async def delete_model(session, model_id="123"):
     """
     Make sure only models user has access to are returned
usage.tsx

@@ -554,10 +554,8 @@ const UsagePage: React.FC<UsagePageProps> = ({
             </Col>
             <Col numColSpan={2}>
               <Card className="mb-2">
-                <Title>✨ Spend by Provider</Title>
-                {
-                  premiumUser ? (
-                    <>
+                <Title>Spend by Provider</Title>
+                <>
                 <Grid numItems={2}>
                   <Col numColSpan={1}>
                     <DonutChart
@@ -592,17 +590,6 @@ const UsagePage: React.FC<UsagePageProps> = ({
                   </Col>
                 </Grid>
                 </>
-                ) : (
-                  <div>
-                    <p className="mb-2 text-gray-500 italic text-[12px]">Upgrade to use this feature</p>
-                    <Button variant="primary" className="mb-2">
-                      <a href="https://forms.gle/W3U4PZpJGFHWtHyA9" target="_blank">
-                        Get Free Trial
-                      </a>
-                    </Button>
-                  </div>
-                )
-              }

               </Card>
             </Col>
@@ -643,9 +630,7 @@ const UsagePage: React.FC<UsagePageProps> = ({

             </Card>

-            {
-              premiumUser ? (
-                <>
+            <>
             {globalActivityPerModel.map((globalActivity, index) => (
               <Card key={index}>
                 <Title>{globalActivity.model}</Title>
@@ -677,69 +662,7 @@ const UsagePage: React.FC<UsagePageProps> = ({
                 </Grid>
               </Card>
             ))}
             </>
-            ) :
-            <>
-            {globalActivityPerModel && globalActivityPerModel.length > 0 &&
-              globalActivityPerModel.slice(0, 1).map((globalActivity, index) => (
-                <Card key={index}>
-                  <Title>✨ Activity by Model</Title>
-                  <p className="mb-2 text-gray-500 italic text-[12px]">Upgrade to see analytics for all models</p>
-                  <Button variant="primary" className="mb-2">
-                    <a href="https://forms.gle/W3U4PZpJGFHWtHyA9" target="_blank">
-                      Get Free Trial
-                    </a>
-                  </Button>
-                  <Card>
-                    <Title>{globalActivity.model}</Title>
-                    <Grid numItems={2}>
-                      <Col>
-                        <Subtitle
-                          style={{
-                            fontSize: "15px",
-                            fontWeight: "normal",
-                            color: "#535452",
-                          }}
-                        >
-                          API Requests {valueFormatterNumbers(globalActivity.sum_api_requests)}
-                        </Subtitle>
-                        <AreaChart
-                          className="h-40"
-                          data={globalActivity.daily_data}
-                          index="date"
-                          colors={['cyan']}
-                          categories={['api_requests']}
-                          valueFormatter={valueFormatterNumbers}
-                          onValueChange={(v) => console.log(v)}
-                        />
-                      </Col>
-                      <Col>
-                        <Subtitle
-                          style={{
-                            fontSize: "15px",
-                            fontWeight: "normal",
-                            color: "#535452",
-                          }}
-                        >
-                          Tokens {valueFormatterNumbers(globalActivity.sum_total_tokens)}
-                        </Subtitle>
-                        <BarChart
-                          className="h-40"
-                          data={globalActivity.daily_data}
-                          index="date"
-                          colors={['cyan']}
-                          valueFormatter={valueFormatterNumbers}
-                          categories={['total_tokens']}
-                          onValueChange={(v) => console.log(v)}
-                        />
-                      </Col>
-
-                    </Grid>
-                  </Card>
-                </Card>
-              ))}
-            </>
-          }
           </Grid>
         </TabPanel>
       </TabPanels>