fix(proxy_server.py): fix get model info when litellm_model_id is set + move model analytics to free (#7886)

* fix(proxy_server.py): fix get model info when litellm_model_id is set

Fixes https://github.com/BerriAI/litellm/issues/7873

* test(test_models.py): add test to ensure get model info on specific deployment has same value as all model info

Fixes https://github.com/BerriAI/litellm/issues/7873
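
For context, a minimal sketch of how the fixed endpoint can be exercised against a locally running proxy. The base URL, key, and model id below are placeholder assumptions; the query parameter and response shape follow the test added in test_models.py:

    import requests  # assumes a litellm proxy is running locally on port 4000

    PROXY_BASE = "http://0.0.0.0:4000"   # placeholder proxy URL
    API_KEY = "sk-1234"                  # placeholder virtual key
    MODEL_ID = "<litellm_model_id>"      # id taken from a previous /model/info response

    headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}

    # all deployments visible to this key
    all_info = requests.get(f"{PROXY_BASE}/model/info", headers=headers).json()["data"]

    # a single deployment by id - after this fix, its model_info should match the entry above
    one_info = requests.get(
        f"{PROXY_BASE}/model/info",
        params={"litellm_model_id": MODEL_ID},
        headers=headers,
    ).json()["data"][0]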

* fix(usage.tsx): make model analytics free

Addresses @iqballx's feedback

* fix(invoke_handler.py): fix bedrock error chunk parsing - return the correct Bedrock status code and error message when an error chunk appears in the stream

Improves bedrock stream error handling
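
As a rough illustration of the intended behavior (not part of this commit), assuming the SDK surfaces the mapped provider error while the stream is being consumed, a caller would now see the upstream Bedrock status code and message instead of a generic ValueError:

    import litellm

    try:
        stream = litellm.completion(
            model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
            messages=[{"role": "user", "content": "hello"}],
            stream=True,
        )
        for chunk in stream:
            print(chunk)
    except Exception as e:
        # with this fix, a mid-stream error chunk carries the Bedrock status code
        # and message (e.g. serviceUnavailableException) rather than a bare ValueError
        print(getattr(e, "status_code", None), str(e))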

* fix(proxy_server.py): fix linting errors

* test(test_auth_checks.py): remove redundant test

* fix(proxy_server.py): fix linting errors

* test: fix flaky test

* test: fix test
Krish Dholakia 2025-01-21 08:19:07 -08:00 committed by GitHub
parent 0295f494b6
commit c8aa876785
8 changed files with 146 additions and 131 deletions


@@ -894,7 +894,7 @@ class BedrockLLM(BaseAWSLLM):
         if response.status_code != 200:
             raise BedrockError(
-                status_code=response.status_code, message=response.read()
+                status_code=response.status_code, message=str(response.read())
             )

         decoder = AWSEventStreamDecoder(model=model)
@@ -1247,7 +1247,23 @@ class AWSEventStreamDecoder:
         parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
         if response_dict["status_code"] != 200:
-            raise ValueError(f"Bad response code, expected 200: {response_dict}")
+            decoded_body = response_dict["body"].decode()
+            if isinstance(decoded_body, dict):
+                error_message = decoded_body.get("message")
+            elif isinstance(decoded_body, str):
+                error_message = decoded_body
+            else:
+                error_message = ""
+            exception_status = response_dict["headers"].get(":exception-type")
+            error_message = exception_status + " " + error_message
+            raise BedrockError(
+                status_code=response_dict["status_code"],
+                message=(
+                    json.dumps(error_message)
+                    if isinstance(error_message, dict)
+                    else error_message
+                ),
+            )
         if "chunk" in parsed_response:
             chunk = parsed_response.get("chunk")
             if not chunk:


@@ -432,6 +432,7 @@ class Huggingface(BaseLLM):
         embed_url: str,
     ) -> dict:
         data: Dict = {}

         ## TRANSFORMATION ##
         if "sentence-transformers" in model:
             if len(input) == 0:


@@ -6,8 +6,7 @@ model_list:
     api_base: https://exampleopenaiendpoint-production.up.railway.app
   - model_name: openai-o1
     litellm_params:
-      model: openai/random_sleep
-      api_key: sk-1234
+      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
       api_base: http://0.0.0.0:8090
       timeout: 2
       num_retries: 0


@@ -274,6 +274,7 @@ from litellm.types.llms.anthropic import (
     AnthropicResponseUsageBlock,
 )
 from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.router import DeploymentTypedDict
 from litellm.types.router import ModelInfo as RouterModelInfo
 from litellm.types.router import RouterGeneralSettings, updateDeployment
 from litellm.types.utils import CustomHuggingfaceTokenizer
@@ -6510,6 +6511,47 @@ async def model_metrics_exceptions(
     return {"data": response, "exception_types": list(exception_types)}


+def _get_proxy_model_info(model: dict) -> dict:
+    # provided model_info in config.yaml
+    model_info = model.get("model_info", {})
+
+    # read litellm model_prices_and_context_window.json to get the following:
+    # input_cost_per_token, output_cost_per_token, max_tokens
+    litellm_model_info = get_litellm_model_info(model=model)
+
+    # 2nd pass on the model, try seeing if we can find model in litellm model_cost map
+    if litellm_model_info == {}:
+        # use litellm_param model_name to get model_info
+        litellm_params = model.get("litellm_params", {})
+        litellm_model = litellm_params.get("model", None)
+        try:
+            litellm_model_info = litellm.get_model_info(model=litellm_model)
+        except Exception:
+            litellm_model_info = {}
+
+    # 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
+    if litellm_model_info == {}:
+        # use litellm_param model_name to get model_info
+        litellm_params = model.get("litellm_params", {})
+        litellm_model = litellm_params.get("model", None)
+        split_model = litellm_model.split("/")
+        if len(split_model) > 0:
+            litellm_model = split_model[-1]
+        try:
+            litellm_model_info = litellm.get_model_info(
+                model=litellm_model, custom_llm_provider=split_model[0]
+            )
+        except Exception:
+            litellm_model_info = {}
+
+    for k, v in litellm_model_info.items():
+        if k not in model_info:
+            model_info[k] = v
+
+    model["model_info"] = model_info
+    # don't return the llm credentials
+    model = remove_sensitive_info_from_deployment(deployment_dict=model)
+
+    return model
+
+
 @router.get(
     "/model/info",
     tags=["model management"],
@@ -6598,16 +6640,15 @@ async def model_info_v1(  # noqa: PLR0915
         deployment_info = llm_router.get_deployment(model_id=litellm_model_id)
         if deployment_info is None:
             raise HTTPException(
-                status_code=404,
+                status_code=400,
                 detail={
                     "error": f"Model id = {litellm_model_id} not found on litellm proxy"
                 },
             )
-        _deployment_info_dict = deployment_info.model_dump()
-        _deployment_info_dict = remove_sensitive_info_from_deployment(
-            deployment_dict=_deployment_info_dict
+        _deployment_info_dict = _get_proxy_model_info(
+            model=deployment_info.model_dump(exclude_none=True)
         )
-        return {"data": _deployment_info_dict}
+        return {"data": [_deployment_info_dict]}

     all_models: List[dict] = []
     model_access_groups: Dict[str, List[str]] = defaultdict(list)
@@ -6647,42 +6688,7 @@ async def model_info_v1(  # noqa: PLR0915
             all_models = []

     for model in all_models:
-        # provided model_info in config.yaml
-        model_info = model.get("model_info", {})
-
-        # read litellm model_prices_and_context_window.json to get the following:
-        # input_cost_per_token, output_cost_per_token, max_tokens
-        litellm_model_info = get_litellm_model_info(model=model)
-
-        # 2nd pass on the model, try seeing if we can find model in litellm model_cost map
-        if litellm_model_info == {}:
-            # use litellm_param model_name to get model_info
-            litellm_params = model.get("litellm_params", {})
-            litellm_model = litellm_params.get("model", None)
-            try:
-                litellm_model_info = litellm.get_model_info(model=litellm_model)
-            except Exception:
-                litellm_model_info = {}
-
-        # 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
-        if litellm_model_info == {}:
-            # use litellm_param model_name to get model_info
-            litellm_params = model.get("litellm_params", {})
-            litellm_model = litellm_params.get("model", None)
-            split_model = litellm_model.split("/")
-            if len(split_model) > 0:
-                litellm_model = split_model[-1]
-            try:
-                litellm_model_info = litellm.get_model_info(
-                    model=litellm_model, custom_llm_provider=split_model[0]
-                )
-            except Exception:
-                litellm_model_info = {}
-
-        for k, v in litellm_model_info.items():
-            if k not in model_info:
-                model_info[k] = v
-
-        model["model_info"] = model_info
-        # don't return the llm credentials
-        model = remove_sensitive_info_from_deployment(deployment_dict=model)
+        model = _get_proxy_model_info(model=model)

     verbose_proxy_logger.debug("all_models: %s", all_models)
     return {"data": all_models}


@@ -2429,3 +2429,33 @@ async def test_bedrock_image_url_sync_client():
     except Exception as e:
         print(e)
     mock_post.assert_called_once()
+
+
+def test_bedrock_error_handling_streaming():
+    from litellm.llms.bedrock.chat.invoke_handler import (
+        AWSEventStreamDecoder,
+        BedrockError,
+    )
+    from unittest.mock import patch, Mock
+
+    event = Mock()
+    event.to_response_dict = Mock(
+        return_value={
+            "status_code": 400,
+            "headers": {
+                ":exception-type": "serviceUnavailableException",
+                ":content-type": "application/json",
+                ":message-type": "exception",
+            },
+            "body": b'{"message":"Bedrock is unable to process your request."}',
+        }
+    )
+
+    decoder = AWSEventStreamDecoder(
+        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
+    )
+
+    with pytest.raises(Exception) as e:
+        decoder._parse_message_from_event(event)
+
+    assert isinstance(e.value, BedrockError)
+    assert "Bedrock is unable to process your request." in e.value.message
+    assert e.value.status_code == 400


@@ -642,19 +642,27 @@ def tgi_mock_post(*args, **kwargs):
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler


-@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_hf_embedding_sentence_sim(sync_mode):
+@patch("litellm.llms.huggingface.chat.handler.async_get_hf_task_embedding_for_model")
+@patch("litellm.llms.huggingface.chat.handler.get_hf_task_embedding_for_model")
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_hf_embedding_sentence_sim(
+    mock_async_get_hf_task_embedding_for_model,
+    mock_get_hf_task_embedding_for_model,
+    sync_mode,
+):
     try:
         # huggingface/microsoft/codebert-base
         # huggingface/facebook/bart-large
+        mock_get_hf_task_embedding_for_model.return_value = "sentence-similarity"
+        mock_async_get_hf_task_embedding_for_model.return_value = "sentence-similarity"
+
         if sync_mode is True:
             client = HTTPHandler(concurrent_limit=1)
         else:
             client = AsyncHTTPHandler(concurrent_limit=1)

         with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
             data = {
-                "model": "huggingface/TaylorAI/bge-micro-v2",
+                "model": "huggingface/sentence-transformers/TaylorAI/bge-micro-v2",
                 "input": ["good morning from litellm", "this is another item"],
                 "client": client,
             }


@@ -88,11 +88,14 @@ async def add_models(session, model_id="123", model_name="azure-gpt-3.5"):
     return response_json


-async def get_model_info(session, key):
+async def get_model_info(session, key, litellm_model_id=None):
     """
     Make sure only models user has access to are returned
     """
-    url = "http://0.0.0.0:4000/model/info"
+    if litellm_model_id:
+        url = f"http://0.0.0.0:4000/model/info?litellm_model_id={litellm_model_id}"
+    else:
+        url = "http://0.0.0.0:4000/model/info"
     headers = {
         "Authorization": f"Bearer {key}",
         "Content-Type": "application/json",
@@ -148,6 +151,35 @@ async def test_get_models():
             assert m == "gpt-4"


+@pytest.mark.asyncio
+async def test_get_specific_model():
+    """
+    Return specific model info
+
+    Ensure value of model_info is same as on `/model/info` (no id set)
+    """
+    async with aiohttp.ClientSession() as session:
+        key_gen = await generate_key(session=session, models=["gpt-4"])
+        key = key_gen["key"]
+        response = await get_model_info(session=session, key=key)
+        models = [m["model_name"] for m in response["data"]]
+        model_specific_info = None
+        for idx, m in enumerate(models):
+            assert m == "gpt-4"
+            litellm_model_id = response["data"][idx]["model_info"]["id"]
+            model_specific_info = response["data"][idx]
+        assert litellm_model_id is not None
+
+        response = await get_model_info(
+            session=session, key=key, litellm_model_id=litellm_model_id
+        )
+        assert response["data"][0]["model_info"]["id"] == litellm_model_id
+
+        assert (
+            response["data"][0] == model_specific_info
+        ), "Model info is not the same. Got={}, Expected={}".format(
+            response["data"][0], model_specific_info
+        )
+
+
 async def delete_model(session, model_id="123"):
     """
     Make sure only models user has access to are returned


@@ -554,10 +554,8 @@ const UsagePage: React.FC<UsagePageProps> = ({
             </Col>
             <Col numColSpan={2}>
               <Card className="mb-2">
-                <Title> Spend by Provider</Title>
-                {
-                  premiumUser ? (
-                    <>
+                <Title>Spend by Provider</Title>
+                <>
                   <Grid numItems={2}>
                     <Col numColSpan={1}>
                       <DonutChart
@@ -592,17 +590,6 @@ const UsagePage: React.FC<UsagePageProps> = ({
                     </Col>
                   </Grid>
                 </>
-                  ) : (
-                    <div>
-                      <p className="mb-2 text-gray-500 italic text-[12px]">Upgrade to use this feature</p>
-                      <Button variant="primary" className="mb-2">
-                        <a href="https://forms.gle/W3U4PZpJGFHWtHyA9" target="_blank">
-                          Get Free Trial
-                        </a>
-                      </Button>
-                    </div>
-                  )
-                }
               </Card>
             </Col>
@@ -643,9 +630,7 @@ const UsagePage: React.FC<UsagePageProps> = ({
               </Card>

-              {
-                premiumUser ? (
-                  <>
+              <>
               {globalActivityPerModel.map((globalActivity, index) => (
                 <Card key={index}>
                   <Title>{globalActivity.model}</Title>
@@ -677,69 +662,7 @@ const UsagePage: React.FC<UsagePageProps> = ({
                   </Grid>
                 </Card>
               ))}
               </>
-              ) :
-              <>
-                {globalActivityPerModel && globalActivityPerModel.length > 0 &&
-                  globalActivityPerModel.slice(0, 1).map((globalActivity, index) => (
-                    <Card key={index}>
-                      <Title> Activity by Model</Title>
-                      <p className="mb-2 text-gray-500 italic text-[12px]">Upgrade to see analytics for all models</p>
-                      <Button variant="primary" className="mb-2">
-                        <a href="https://forms.gle/W3U4PZpJGFHWtHyA9" target="_blank">
-                          Get Free Trial
-                        </a>
-                      </Button>
-                      <Card>
-                        <Title>{globalActivity.model}</Title>
-                        <Grid numItems={2}>
-                          <Col>
-                            <Subtitle
-                              style={{
-                                fontSize: "15px",
-                                fontWeight: "normal",
-                                color: "#535452",
-                              }}
-                            >
-                              API Requests {valueFormatterNumbers(globalActivity.sum_api_requests)}
-                            </Subtitle>
-                            <AreaChart
-                              className="h-40"
-                              data={globalActivity.daily_data}
-                              index="date"
-                              colors={['cyan']}
-                              categories={['api_requests']}
-                              valueFormatter={valueFormatterNumbers}
-                              onValueChange={(v) => console.log(v)}
-                            />
-                          </Col>
-                          <Col>
-                            <Subtitle
-                              style={{
-                                fontSize: "15px",
-                                fontWeight: "normal",
-                                color: "#535452",
-                              }}
-                            >
-                              Tokens {valueFormatterNumbers(globalActivity.sum_total_tokens)}
-                            </Subtitle>
-                            <BarChart
-                              className="h-40"
-                              data={globalActivity.daily_data}
-                              index="date"
-                              colors={['cyan']}
-                              valueFormatter={valueFormatterNumbers}
-                              categories={['total_tokens']}
-                              onValueChange={(v) => console.log(v)}
-                            />
-                          </Col>
-                        </Grid>
-                      </Card>
-                    </Card>
-                  ))}
-              </>
-              }
             </Grid>
           </TabPanel>
         </TabPanels>