mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(proxy_server.py): fix get model info when litellm_model_id is set + move model analytics to free (#7886)
* fix(proxy_server.py): fix get model info when litellm_model_id is set Fixes https://github.com/BerriAI/litellm/issues/7873 * test(test_models.py): add test to ensure get model info on specific deployment has same value as all model info Fixes https://github.com/BerriAI/litellm/issues/7873 * fix(usage.tsx): make model analytics free Fixes @iqballx's feedback * fix(fix(invoke_handler.py):-fix-bedrock-error-chunk-parsing): return correct bedrock status code and error message if chunk in stream Improves bedrock stream error handling * fix(proxy_server.py): fix linting errors * test(test_auth_checks.py): remove redundant test * fix(proxy_server.py): fix linting errors * test: fix flaky test * test: fix test
This commit is contained in:
parent
0295f494b6
commit
c8aa876785
8 changed files with 146 additions and 131 deletions
|
@ -894,7 +894,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
|
||||
if response.status_code != 200:
|
||||
raise BedrockError(
|
||||
status_code=response.status_code, message=response.read()
|
||||
status_code=response.status_code, message=str(response.read())
|
||||
)
|
||||
|
||||
decoder = AWSEventStreamDecoder(model=model)
|
||||
|
@ -1247,7 +1247,23 @@ class AWSEventStreamDecoder:
|
|||
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
|
||||
|
||||
if response_dict["status_code"] != 200:
|
||||
raise ValueError(f"Bad response code, expected 200: {response_dict}")
|
||||
decoded_body = response_dict["body"].decode()
|
||||
if isinstance(decoded_body, dict):
|
||||
error_message = decoded_body.get("message")
|
||||
elif isinstance(decoded_body, str):
|
||||
error_message = decoded_body
|
||||
else:
|
||||
error_message = ""
|
||||
exception_status = response_dict["headers"].get(":exception-type")
|
||||
error_message = exception_status + " " + error_message
|
||||
raise BedrockError(
|
||||
status_code=response_dict["status_code"],
|
||||
message=(
|
||||
json.dumps(error_message)
|
||||
if isinstance(error_message, dict)
|
||||
else error_message
|
||||
),
|
||||
)
|
||||
if "chunk" in parsed_response:
|
||||
chunk = parsed_response.get("chunk")
|
||||
if not chunk:
|
||||
|
|
|
@ -432,6 +432,7 @@ class Huggingface(BaseLLM):
|
|||
embed_url: str,
|
||||
) -> dict:
|
||||
data: Dict = {}
|
||||
|
||||
## TRANSFORMATION ##
|
||||
if "sentence-transformers" in model:
|
||||
if len(input) == 0:
|
||||
|
|
|
@ -6,8 +6,7 @@ model_list:
|
|||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
||||
- model_name: openai-o1
|
||||
litellm_params:
|
||||
model: openai/random_sleep
|
||||
api_key: sk-1234
|
||||
model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
|
||||
api_base: http://0.0.0.0:8090
|
||||
timeout: 2
|
||||
num_retries: 0
|
||||
|
|
|
@ -274,6 +274,7 @@ from litellm.types.llms.anthropic import (
|
|||
AnthropicResponseUsageBlock,
|
||||
)
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
from litellm.types.router import DeploymentTypedDict
|
||||
from litellm.types.router import ModelInfo as RouterModelInfo
|
||||
from litellm.types.router import RouterGeneralSettings, updateDeployment
|
||||
from litellm.types.utils import CustomHuggingfaceTokenizer
|
||||
|
@ -6510,6 +6511,47 @@ async def model_metrics_exceptions(
|
|||
return {"data": response, "exception_types": list(exception_types)}
|
||||
|
||||
|
||||
def _get_proxy_model_info(model: dict) -> dict:
|
||||
# provided model_info in config.yaml
|
||||
model_info = model.get("model_info", {})
|
||||
|
||||
# read litellm model_prices_and_context_window.json to get the following:
|
||||
# input_cost_per_token, output_cost_per_token, max_tokens
|
||||
litellm_model_info = get_litellm_model_info(model=model)
|
||||
|
||||
# 2nd pass on the model, try seeing if we can find model in litellm model_cost map
|
||||
if litellm_model_info == {}:
|
||||
# use litellm_param model_name to get model_info
|
||||
litellm_params = model.get("litellm_params", {})
|
||||
litellm_model = litellm_params.get("model", None)
|
||||
try:
|
||||
litellm_model_info = litellm.get_model_info(model=litellm_model)
|
||||
except Exception:
|
||||
litellm_model_info = {}
|
||||
# 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
|
||||
if litellm_model_info == {}:
|
||||
# use litellm_param model_name to get model_info
|
||||
litellm_params = model.get("litellm_params", {})
|
||||
litellm_model = litellm_params.get("model", None)
|
||||
split_model = litellm_model.split("/")
|
||||
if len(split_model) > 0:
|
||||
litellm_model = split_model[-1]
|
||||
try:
|
||||
litellm_model_info = litellm.get_model_info(
|
||||
model=litellm_model, custom_llm_provider=split_model[0]
|
||||
)
|
||||
except Exception:
|
||||
litellm_model_info = {}
|
||||
for k, v in litellm_model_info.items():
|
||||
if k not in model_info:
|
||||
model_info[k] = v
|
||||
model["model_info"] = model_info
|
||||
# don't return the llm credentials
|
||||
model = remove_sensitive_info_from_deployment(deployment_dict=model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@router.get(
|
||||
"/model/info",
|
||||
tags=["model management"],
|
||||
|
@ -6598,16 +6640,15 @@ async def model_info_v1( # noqa: PLR0915
|
|||
deployment_info = llm_router.get_deployment(model_id=litellm_model_id)
|
||||
if deployment_info is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"Model id = {litellm_model_id} not found on litellm proxy"
|
||||
},
|
||||
)
|
||||
_deployment_info_dict = deployment_info.model_dump()
|
||||
_deployment_info_dict = remove_sensitive_info_from_deployment(
|
||||
deployment_dict=_deployment_info_dict
|
||||
_deployment_info_dict = _get_proxy_model_info(
|
||||
model=deployment_info.model_dump(exclude_none=True)
|
||||
)
|
||||
return {"data": _deployment_info_dict}
|
||||
return {"data": [_deployment_info_dict]}
|
||||
|
||||
all_models: List[dict] = []
|
||||
model_access_groups: Dict[str, List[str]] = defaultdict(list)
|
||||
|
@ -6647,42 +6688,7 @@ async def model_info_v1( # noqa: PLR0915
|
|||
all_models = []
|
||||
|
||||
for model in all_models:
|
||||
# provided model_info in config.yaml
|
||||
model_info = model.get("model_info", {})
|
||||
|
||||
# read litellm model_prices_and_context_window.json to get the following:
|
||||
# input_cost_per_token, output_cost_per_token, max_tokens
|
||||
litellm_model_info = get_litellm_model_info(model=model)
|
||||
|
||||
# 2nd pass on the model, try seeing if we can find model in litellm model_cost map
|
||||
if litellm_model_info == {}:
|
||||
# use litellm_param model_name to get model_info
|
||||
litellm_params = model.get("litellm_params", {})
|
||||
litellm_model = litellm_params.get("model", None)
|
||||
try:
|
||||
litellm_model_info = litellm.get_model_info(model=litellm_model)
|
||||
except Exception:
|
||||
litellm_model_info = {}
|
||||
# 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
|
||||
if litellm_model_info == {}:
|
||||
# use litellm_param model_name to get model_info
|
||||
litellm_params = model.get("litellm_params", {})
|
||||
litellm_model = litellm_params.get("model", None)
|
||||
split_model = litellm_model.split("/")
|
||||
if len(split_model) > 0:
|
||||
litellm_model = split_model[-1]
|
||||
try:
|
||||
litellm_model_info = litellm.get_model_info(
|
||||
model=litellm_model, custom_llm_provider=split_model[0]
|
||||
)
|
||||
except Exception:
|
||||
litellm_model_info = {}
|
||||
for k, v in litellm_model_info.items():
|
||||
if k not in model_info:
|
||||
model_info[k] = v
|
||||
model["model_info"] = model_info
|
||||
# don't return the llm credentials
|
||||
model = remove_sensitive_info_from_deployment(deployment_dict=model)
|
||||
model = _get_proxy_model_info(model=model)
|
||||
|
||||
verbose_proxy_logger.debug("all_models: %s", all_models)
|
||||
return {"data": all_models}
|
||||
|
|
|
@ -2429,3 +2429,33 @@ async def test_bedrock_image_url_sync_client():
|
|||
except Exception as e:
|
||||
print(e)
|
||||
mock_post.assert_called_once()
|
||||
|
||||
|
||||
def test_bedrock_error_handling_streaming():
|
||||
from litellm.llms.bedrock.chat.invoke_handler import (
|
||||
AWSEventStreamDecoder,
|
||||
BedrockError,
|
||||
)
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
event = Mock()
|
||||
event.to_response_dict = Mock(
|
||||
return_value={
|
||||
"status_code": 400,
|
||||
"headers": {
|
||||
":exception-type": "serviceUnavailableException",
|
||||
":content-type": "application/json",
|
||||
":message-type": "exception",
|
||||
},
|
||||
"body": b'{"message":"Bedrock is unable to process your request."}',
|
||||
}
|
||||
)
|
||||
|
||||
decoder = AWSEventStreamDecoder(
|
||||
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
)
|
||||
with pytest.raises(Exception) as e:
|
||||
decoder._parse_message_from_event(event)
|
||||
assert isinstance(e.value, BedrockError)
|
||||
assert "Bedrock is unable to process your request." in e.value.message
|
||||
assert e.value.status_code == 400
|
||||
|
|
|
@ -642,19 +642,27 @@ def tgi_mock_post(*args, **kwargs):
|
|||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_hf_embedding_sentence_sim(sync_mode):
|
||||
@patch("litellm.llms.huggingface.chat.handler.async_get_hf_task_embedding_for_model")
|
||||
@patch("litellm.llms.huggingface.chat.handler.get_hf_task_embedding_for_model")
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
async def test_hf_embedding_sentence_sim(
|
||||
mock_async_get_hf_task_embedding_for_model,
|
||||
mock_get_hf_task_embedding_for_model,
|
||||
sync_mode,
|
||||
):
|
||||
try:
|
||||
# huggingface/microsoft/codebert-base
|
||||
# huggingface/facebook/bart-large
|
||||
mock_get_hf_task_embedding_for_model.return_value = "sentence-similarity"
|
||||
mock_async_get_hf_task_embedding_for_model.return_value = "sentence-similarity"
|
||||
if sync_mode is True:
|
||||
client = HTTPHandler(concurrent_limit=1)
|
||||
else:
|
||||
client = AsyncHTTPHandler(concurrent_limit=1)
|
||||
with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
|
||||
data = {
|
||||
"model": "huggingface/TaylorAI/bge-micro-v2",
|
||||
"model": "huggingface/sentence-transformers/TaylorAI/bge-micro-v2",
|
||||
"input": ["good morning from litellm", "this is another item"],
|
||||
"client": client,
|
||||
}
|
||||
|
|
|
@ -88,10 +88,13 @@ async def add_models(session, model_id="123", model_name="azure-gpt-3.5"):
|
|||
return response_json
|
||||
|
||||
|
||||
async def get_model_info(session, key):
|
||||
async def get_model_info(session, key, litellm_model_id=None):
|
||||
"""
|
||||
Make sure only models user has access to are returned
|
||||
"""
|
||||
if litellm_model_id:
|
||||
url = f"http://0.0.0.0:4000/model/info?litellm_model_id={litellm_model_id}"
|
||||
else:
|
||||
url = "http://0.0.0.0:4000/model/info"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {key}",
|
||||
|
@ -148,6 +151,35 @@ async def test_get_models():
|
|||
assert m == "gpt-4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_specific_model():
|
||||
"""
|
||||
Return specific model info
|
||||
|
||||
Ensure value of model_info is same as on `/model/info` (no id set)
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
key_gen = await generate_key(session=session, models=["gpt-4"])
|
||||
key = key_gen["key"]
|
||||
response = await get_model_info(session=session, key=key)
|
||||
models = [m["model_name"] for m in response["data"]]
|
||||
model_specific_info = None
|
||||
for idx, m in enumerate(models):
|
||||
assert m == "gpt-4"
|
||||
litellm_model_id = response["data"][idx]["model_info"]["id"]
|
||||
model_specific_info = response["data"][idx]
|
||||
assert litellm_model_id is not None
|
||||
response = await get_model_info(
|
||||
session=session, key=key, litellm_model_id=litellm_model_id
|
||||
)
|
||||
assert response["data"][0]["model_info"]["id"] == litellm_model_id
|
||||
assert (
|
||||
response["data"][0] == model_specific_info
|
||||
), "Model info is not the same. Got={}, Expected={}".format(
|
||||
response["data"][0], model_specific_info
|
||||
)
|
||||
|
||||
|
||||
async def delete_model(session, model_id="123"):
|
||||
"""
|
||||
Make sure only models user has access to are returned
|
||||
|
|
|
@ -554,9 +554,7 @@ const UsagePage: React.FC<UsagePageProps> = ({
|
|||
</Col>
|
||||
<Col numColSpan={2}>
|
||||
<Card className="mb-2">
|
||||
<Title>✨ Spend by Provider</Title>
|
||||
{
|
||||
premiumUser ? (
|
||||
<Title>Spend by Provider</Title>
|
||||
<>
|
||||
<Grid numItems={2}>
|
||||
<Col numColSpan={1}>
|
||||
|
@ -592,17 +590,6 @@ const UsagePage: React.FC<UsagePageProps> = ({
|
|||
</Col>
|
||||
</Grid>
|
||||
</>
|
||||
) : (
|
||||
<div>
|
||||
<p className="mb-2 text-gray-500 italic text-[12px]">Upgrade to use this feature</p>
|
||||
<Button variant="primary" className="mb-2">
|
||||
<a href="https://forms.gle/W3U4PZpJGFHWtHyA9" target="_blank">
|
||||
Get Free Trial
|
||||
</a>
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
</Card>
|
||||
</Col>
|
||||
|
@ -643,8 +630,6 @@ const UsagePage: React.FC<UsagePageProps> = ({
|
|||
|
||||
</Card>
|
||||
|
||||
{
|
||||
premiumUser ? (
|
||||
<>
|
||||
{globalActivityPerModel.map((globalActivity, index) => (
|
||||
<Card key={index}>
|
||||
|
@ -678,68 +663,6 @@ const UsagePage: React.FC<UsagePageProps> = ({
|
|||
</Card>
|
||||
))}
|
||||
</>
|
||||
) :
|
||||
<>
|
||||
{globalActivityPerModel && globalActivityPerModel.length > 0 &&
|
||||
globalActivityPerModel.slice(0, 1).map((globalActivity, index) => (
|
||||
<Card key={index}>
|
||||
<Title>✨ Activity by Model</Title>
|
||||
<p className="mb-2 text-gray-500 italic text-[12px]">Upgrade to see analytics for all models</p>
|
||||
<Button variant="primary" className="mb-2">
|
||||
<a href="https://forms.gle/W3U4PZpJGFHWtHyA9" target="_blank">
|
||||
Get Free Trial
|
||||
</a>
|
||||
</Button>
|
||||
<Card>
|
||||
<Title>{globalActivity.model}</Title>
|
||||
<Grid numItems={2}>
|
||||
<Col>
|
||||
<Subtitle
|
||||
style={{
|
||||
fontSize: "15px",
|
||||
fontWeight: "normal",
|
||||
color: "#535452",
|
||||
}}
|
||||
>
|
||||
API Requests {valueFormatterNumbers(globalActivity.sum_api_requests)}
|
||||
</Subtitle>
|
||||
<AreaChart
|
||||
className="h-40"
|
||||
data={globalActivity.daily_data}
|
||||
index="date"
|
||||
colors={['cyan']}
|
||||
categories={['api_requests']}
|
||||
valueFormatter={valueFormatterNumbers}
|
||||
onValueChange={(v) => console.log(v)}
|
||||
/>
|
||||
</Col>
|
||||
<Col>
|
||||
<Subtitle
|
||||
style={{
|
||||
fontSize: "15px",
|
||||
fontWeight: "normal",
|
||||
color: "#535452",
|
||||
}}
|
||||
>
|
||||
Tokens {valueFormatterNumbers(globalActivity.sum_total_tokens)}
|
||||
</Subtitle>
|
||||
<BarChart
|
||||
className="h-40"
|
||||
data={globalActivity.daily_data}
|
||||
index="date"
|
||||
colors={['cyan']}
|
||||
valueFormatter={valueFormatterNumbers}
|
||||
categories={['total_tokens']}
|
||||
onValueChange={(v) => console.log(v)}
|
||||
/>
|
||||
</Col>
|
||||
|
||||
</Grid>
|
||||
</Card>
|
||||
</Card>
|
||||
))}
|
||||
</>
|
||||
}
|
||||
</Grid>
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue