fix(proxy_server.py): fix get model info when litellm_model_id is set + move model analytics to free (#7886)

* fix(proxy_server.py): fix get model info when litellm_model_id is set

Fixes https://github.com/BerriAI/litellm/issues/7873

* test(test_models.py): add test to ensure model info for a specific deployment matches the corresponding entry in the full /model/info response

Fixes https://github.com/BerriAI/litellm/issues/7873
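
As a rough illustration of the behaviour the new test pins down (a sketch only: the proxy address and the `sk-1234` key are placeholders, and `requests` stands in for the aiohttp client used in the test), querying `/model/info` with a `litellm_model_id` should now return the same enriched `model_info` block as the unfiltered listing:

import requests

PROXY_BASE = "http://0.0.0.0:4000"             # placeholder: local proxy
HEADERS = {"Authorization": "Bearer sk-1234"}  # placeholder: admin/virtual key

# Unfiltered listing: every deployment, each with an enriched `model_info` block.
all_models = requests.get(f"{PROXY_BASE}/model/info", headers=HEADERS).json()["data"]

# Filtered by id: should carry the exact same `model_info` for that deployment.
model_id = all_models[0]["model_info"]["id"]
single = requests.get(
    f"{PROXY_BASE}/model/info",
    headers=HEADERS,
    params={"litellm_model_id": model_id},
).json()["data"][0]

assert single["model_info"]["id"] == model_id
assert single == all_models[0]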

* fix(usage.tsx): make model analytics free

Addresses @iqballx's feedback

* fix(invoke_handler.py): fix bedrock error chunk parsing - return correct bedrock status code and error message when an error chunk appears in the stream

Improves bedrock stream error handling
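
A standalone sketch of what the new parsing aims for (this is an illustrative helper, not litellm's API; the JSON decoding of the body is a simplification here, while the committed code passes the decoded body string through as-is): a non-200 event-stream chunk should surface its real status code plus a message built from the `:exception-type` header and the chunk body.

import json


def summarize_bedrock_error_chunk(response_dict: dict) -> tuple[int, str]:
    # hypothetical helper for illustration only
    decoded_body = response_dict["body"].decode()
    try:
        parsed = json.loads(decoded_body)
    except json.JSONDecodeError:
        parsed = None
    # error chunks usually carry a JSON body with a "message" field
    body_message = (
        parsed.get("message", decoded_body) if isinstance(parsed, dict) else decoded_body
    )
    exception_type = response_dict["headers"].get(":exception-type", "")
    return response_dict["status_code"], f"{exception_type} {body_message}".strip()


status, message = summarize_bedrock_error_chunk(
    {
        "status_code": 400,
        "headers": {":exception-type": "serviceUnavailableException"},
        "body": b'{"message":"Bedrock is unable to process your request."}',
    }
)
assert status == 400
assert "Bedrock is unable to process your request." in message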

* fix(proxy_server.py): fix linting errors

* test(test_auth_checks.py): remove redundant test

* fix(proxy_server.py): fix linting errors

* test: fix flaky test

* test: fix test
Krish Dholakia 2025-01-21 08:19:07 -08:00 committed by GitHub
parent 0295f494b6
commit c8aa876785
8 changed files with 146 additions and 131 deletions


@@ -894,7 +894,7 @@ class BedrockLLM(BaseAWSLLM):
        if response.status_code != 200:
            raise BedrockError(
-                status_code=response.status_code, message=response.read()
+                status_code=response.status_code, message=str(response.read())
            )

        decoder = AWSEventStreamDecoder(model=model)
@@ -1247,7 +1247,23 @@ class AWSEventStreamDecoder:
        parsed_response = self.parser.parse(response_dict, get_response_stream_shape())

        if response_dict["status_code"] != 200:
-            raise ValueError(f"Bad response code, expected 200: {response_dict}")
+            decoded_body = response_dict["body"].decode()
+            if isinstance(decoded_body, dict):
+                error_message = decoded_body.get("message")
+            elif isinstance(decoded_body, str):
+                error_message = decoded_body
+            else:
+                error_message = ""
+            exception_status = response_dict["headers"].get(":exception-type")
+            error_message = exception_status + " " + error_message
+            raise BedrockError(
+                status_code=response_dict["status_code"],
+                message=(
+                    json.dumps(error_message)
+                    if isinstance(error_message, dict)
+                    else error_message
+                ),
+            )
        if "chunk" in parsed_response:
            chunk = parsed_response.get("chunk")
            if not chunk:


@@ -432,6 +432,7 @@ class Huggingface(BaseLLM):
        embed_url: str,
    ) -> dict:
        data: Dict = {}
+        ## TRANSFORMATION ##
        if "sentence-transformers" in model:
            if len(input) == 0:


@@ -6,8 +6,7 @@ model_list:
      api_base: https://exampleopenaiendpoint-production.up.railway.app
  - model_name: openai-o1
    litellm_params:
-      model: openai/random_sleep
-      api_key: sk-1234
+      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
      api_base: http://0.0.0.0:8090
      timeout: 2
      num_retries: 0


@@ -274,6 +274,7 @@ from litellm.types.llms.anthropic import (
    AnthropicResponseUsageBlock,
)
from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.router import DeploymentTypedDict
from litellm.types.router import ModelInfo as RouterModelInfo
from litellm.types.router import RouterGeneralSettings, updateDeployment
from litellm.types.utils import CustomHuggingfaceTokenizer
@@ -6510,6 +6511,47 @@ async def model_metrics_exceptions(
    return {"data": response, "exception_types": list(exception_types)}


+def _get_proxy_model_info(model: dict) -> dict:
+    # provided model_info in config.yaml
+    model_info = model.get("model_info", {})
+    # read litellm model_prices_and_context_window.json to get the following:
+    # input_cost_per_token, output_cost_per_token, max_tokens
+    litellm_model_info = get_litellm_model_info(model=model)
+    # 2nd pass on the model, try seeing if we can find model in litellm model_cost map
+    if litellm_model_info == {}:
+        # use litellm_param model_name to get model_info
+        litellm_params = model.get("litellm_params", {})
+        litellm_model = litellm_params.get("model", None)
+        try:
+            litellm_model_info = litellm.get_model_info(model=litellm_model)
+        except Exception:
+            litellm_model_info = {}
+    # 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
+    if litellm_model_info == {}:
+        # use litellm_param model_name to get model_info
+        litellm_params = model.get("litellm_params", {})
+        litellm_model = litellm_params.get("model", None)
+        split_model = litellm_model.split("/")
+        if len(split_model) > 0:
+            litellm_model = split_model[-1]
+        try:
+            litellm_model_info = litellm.get_model_info(
+                model=litellm_model, custom_llm_provider=split_model[0]
+            )
+        except Exception:
+            litellm_model_info = {}
+    for k, v in litellm_model_info.items():
+        if k not in model_info:
+            model_info[k] = v
+    model["model_info"] = model_info
+    # don't return the llm credentials
+    model = remove_sensitive_info_from_deployment(deployment_dict=model)
+    return model
+
+
@router.get(
    "/model/info",
    tags=["model management"],
@@ -6598,16 +6640,15 @@ async def model_info_v1( # noqa: PLR0915
        deployment_info = llm_router.get_deployment(model_id=litellm_model_id)
        if deployment_info is None:
            raise HTTPException(
-                status_code=404,
+                status_code=400,
                detail={
                    "error": f"Model id = {litellm_model_id} not found on litellm proxy"
                },
            )
-        _deployment_info_dict = deployment_info.model_dump()
-        _deployment_info_dict = remove_sensitive_info_from_deployment(
-            deployment_dict=_deployment_info_dict
+        _deployment_info_dict = _get_proxy_model_info(
+            model=deployment_info.model_dump(exclude_none=True)
        )
-        return {"data": _deployment_info_dict}
+        return {"data": [_deployment_info_dict]}

    all_models: List[dict] = []
    model_access_groups: Dict[str, List[str]] = defaultdict(list)
@@ -6647,42 +6688,7 @@ async def model_info_v1( # noqa: PLR0915
            all_models = []

    for model in all_models:
-        # provided model_info in config.yaml
-        model_info = model.get("model_info", {})
-        # read litellm model_prices_and_context_window.json to get the following:
-        # input_cost_per_token, output_cost_per_token, max_tokens
-        litellm_model_info = get_litellm_model_info(model=model)
-        # 2nd pass on the model, try seeing if we can find model in litellm model_cost map
-        if litellm_model_info == {}:
-            # use litellm_param model_name to get model_info
-            litellm_params = model.get("litellm_params", {})
-            litellm_model = litellm_params.get("model", None)
-            try:
-                litellm_model_info = litellm.get_model_info(model=litellm_model)
-            except Exception:
-                litellm_model_info = {}
-        # 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
-        if litellm_model_info == {}:
-            # use litellm_param model_name to get model_info
-            litellm_params = model.get("litellm_params", {})
-            litellm_model = litellm_params.get("model", None)
-            split_model = litellm_model.split("/")
-            if len(split_model) > 0:
-                litellm_model = split_model[-1]
-            try:
-                litellm_model_info = litellm.get_model_info(
-                    model=litellm_model, custom_llm_provider=split_model[0]
-                )
-            except Exception:
-                litellm_model_info = {}
-        for k, v in litellm_model_info.items():
-            if k not in model_info:
-                model_info[k] = v
-        model["model_info"] = model_info
-        # don't return the llm credentials
-        model = remove_sensitive_info_from_deployment(deployment_dict=model)
+        model = _get_proxy_model_info(model=model)

    verbose_proxy_logger.debug("all_models: %s", all_models)
    return {"data": all_models}


@@ -2429,3 +2429,33 @@ async def test_bedrock_image_url_sync_client():
    except Exception as e:
        print(e)
    mock_post.assert_called_once()
+
+
+def test_bedrock_error_handling_streaming():
+    from litellm.llms.bedrock.chat.invoke_handler import (
+        AWSEventStreamDecoder,
+        BedrockError,
+    )
+    from unittest.mock import patch, Mock
+
+    event = Mock()
+    event.to_response_dict = Mock(
+        return_value={
+            "status_code": 400,
+            "headers": {
+                ":exception-type": "serviceUnavailableException",
+                ":content-type": "application/json",
+                ":message-type": "exception",
+            },
+            "body": b'{"message":"Bedrock is unable to process your request."}',
+        }
+    )
+    decoder = AWSEventStreamDecoder(
+        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
+    )
+    with pytest.raises(Exception) as e:
+        decoder._parse_message_from_event(event)
+
+    assert isinstance(e.value, BedrockError)
+    assert "Bedrock is unable to process your request." in e.value.message
+    assert e.value.status_code == 400


@@ -642,19 +642,27 @@ def tgi_mock_post(*args, **kwargs):
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler


-@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
-async def test_hf_embedding_sentence_sim(sync_mode):
+@patch("litellm.llms.huggingface.chat.handler.async_get_hf_task_embedding_for_model")
+@patch("litellm.llms.huggingface.chat.handler.get_hf_task_embedding_for_model")
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_hf_embedding_sentence_sim(
+    mock_async_get_hf_task_embedding_for_model,
+    mock_get_hf_task_embedding_for_model,
+    sync_mode,
+):
    try:
        # huggingface/microsoft/codebert-base
        # huggingface/facebook/bart-large
+        mock_get_hf_task_embedding_for_model.return_value = "sentence-similarity"
+        mock_async_get_hf_task_embedding_for_model.return_value = "sentence-similarity"
        if sync_mode is True:
            client = HTTPHandler(concurrent_limit=1)
        else:
            client = AsyncHTTPHandler(concurrent_limit=1)
        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
            data = {
-                "model": "huggingface/TaylorAI/bge-micro-v2",
+                "model": "huggingface/sentence-transformers/TaylorAI/bge-micro-v2",
                "input": ["good morning from litellm", "this is another item"],
                "client": client,
            }
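
Outside of the mocked test, the embedding route being exercised can be called roughly like this (a sketch only: it hits the Hugging Face inference API and assumes a Hugging Face token such as `HF_TOKEN` / `HUGGINGFACE_API_KEY` is set in the environment):

import litellm

response = litellm.embedding(
    model="huggingface/sentence-transformers/TaylorAI/bge-micro-v2",
    input=["good morning from litellm", "this is another item"],
)
# response follows the OpenAI-style EmbeddingResponse shape
print(response.model, len(response.data))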


@@ -88,10 +88,13 @@ async def add_models(session, model_id="123", model_name="azure-gpt-3.5"):
    return response_json


-async def get_model_info(session, key):
+async def get_model_info(session, key, litellm_model_id=None):
    """
    Make sure only models user has access to are returned
    """
+    if litellm_model_id:
+        url = f"http://0.0.0.0:4000/model/info?litellm_model_id={litellm_model_id}"
+    else:
+        url = "http://0.0.0.0:4000/model/info"
    headers = {
        "Authorization": f"Bearer {key}",
@@ -148,6 +151,35 @@ async def test_get_models():
        assert m == "gpt-4"


+@pytest.mark.asyncio
+async def test_get_specific_model():
+    """
+    Return specific model info
+    Ensure value of model_info is same as on `/model/info` (no id set)
+    """
+    async with aiohttp.ClientSession() as session:
+        key_gen = await generate_key(session=session, models=["gpt-4"])
+        key = key_gen["key"]
+        response = await get_model_info(session=session, key=key)
+        models = [m["model_name"] for m in response["data"]]
+        model_specific_info = None
+        for idx, m in enumerate(models):
+            assert m == "gpt-4"
+            litellm_model_id = response["data"][idx]["model_info"]["id"]
+            model_specific_info = response["data"][idx]
+            assert litellm_model_id is not None
+        response = await get_model_info(
+            session=session, key=key, litellm_model_id=litellm_model_id
+        )
+        assert response["data"][0]["model_info"]["id"] == litellm_model_id
+        assert (
+            response["data"][0] == model_specific_info
+        ), "Model info is not the same. Got={}, Expected={}".format(
+            response["data"][0], model_specific_info
+        )


async def delete_model(session, model_id="123"):
    """
    Make sure only models user has access to are returned


@@ -554,9 +554,7 @@ const UsagePage: React.FC<UsagePageProps> = ({
            </Col>
            <Col numColSpan={2}>
              <Card className="mb-2">
-                <Title> Spend by Provider</Title>
-                {
-                  premiumUser ? (
+                <Title>Spend by Provider</Title>
                <>
                  <Grid numItems={2}>
                    <Col numColSpan={1}>
@@ -592,17 +590,6 @@ const UsagePage: React.FC<UsagePageProps> = ({
                    </Col>
                  </Grid>
                </>
-                ) : (
-                  <div>
-                    <p className="mb-2 text-gray-500 italic text-[12px]">Upgrade to use this feature</p>
-                    <Button variant="primary" className="mb-2">
-                      <a href="https://forms.gle/W3U4PZpJGFHWtHyA9" target="_blank">
-                        Get Free Trial
-                      </a>
-                    </Button>
-                  </div>
-                )
-                }
              </Card>
            </Col>
@@ -643,8 +630,6 @@ const UsagePage: React.FC<UsagePageProps> = ({
          </Card>
          {
-            premiumUser ? (
            <>
            {globalActivityPerModel.map((globalActivity, index) => (
              <Card key={index}>
@@ -678,68 +663,6 @@ const UsagePage: React.FC<UsagePageProps> = ({
              </Card>
            ))}
            </>
-            ) :
-            <>
-            {globalActivityPerModel && globalActivityPerModel.length > 0 &&
-              globalActivityPerModel.slice(0, 1).map((globalActivity, index) => (
-                <Card key={index}>
-                  <Title> Activity by Model</Title>
-                  <p className="mb-2 text-gray-500 italic text-[12px]">Upgrade to see analytics for all models</p>
-                  <Button variant="primary" className="mb-2">
-                    <a href="https://forms.gle/W3U4PZpJGFHWtHyA9" target="_blank">
-                      Get Free Trial
-                    </a>
-                  </Button>
-                  <Card>
-                    <Title>{globalActivity.model}</Title>
-                    <Grid numItems={2}>
-                      <Col>
-                        <Subtitle
-                          style={{
-                            fontSize: "15px",
-                            fontWeight: "normal",
-                            color: "#535452",
-                          }}
-                        >
-                          API Requests {valueFormatterNumbers(globalActivity.sum_api_requests)}
-                        </Subtitle>
-                        <AreaChart
-                          className="h-40"
-                          data={globalActivity.daily_data}
-                          index="date"
-                          colors={['cyan']}
-                          categories={['api_requests']}
-                          valueFormatter={valueFormatterNumbers}
-                          onValueChange={(v) => console.log(v)}
-                        />
-                      </Col>
-                      <Col>
-                        <Subtitle
-                          style={{
-                            fontSize: "15px",
-                            fontWeight: "normal",
-                            color: "#535452",
-                          }}
-                        >
-                          Tokens {valueFormatterNumbers(globalActivity.sum_total_tokens)}
-                        </Subtitle>
-                        <BarChart
-                          className="h-40"
-                          data={globalActivity.daily_data}
-                          index="date"
-                          colors={['cyan']}
-                          valueFormatter={valueFormatterNumbers}
-                          categories={['total_tokens']}
-                          onValueChange={(v) => console.log(v)}
-                        />
-                      </Col>
-                    </Grid>
-                  </Card>
-                </Card>
-              ))}
-            </>
          }
        </Grid>
      </TabPanel>
    </TabPanels>