forked from phoenix/litellm-mirror
Merge pull request #1874 from BerriAI/litellm_azure_base_model_pricing
[FEAT] Azure Pricing - based on base_model in model_info
Commit: 98b0ace2e9
6 changed files with 121 additions and 17 deletions
@@ -246,6 +246,28 @@ model_list:
$ litellm --config /path/to/config.yaml
```

## Set Azure `base_model` for cost tracking
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking.
**Solution** ✅: Set `base_model` in your config so litellm uses the correct model when calculating Azure costs.
Example config with `base_model`
```yaml
model_list:
  - model_name: azure-gpt-3.5
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: os.environ/AZURE_API_BASE
      api_key: os.environ/AZURE_API_KEY
      api_version: "2023-07-01-preview"
    model_info:
      base_model: azure/gpt-4-1106-preview
```
You can view your costs once you set up [Virtual keys](https://docs.litellm.ai/docs/proxy/virtual_keys) or [custom_callbacks](https://docs.litellm.ai/docs/proxy/logging).
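
To see the effect of `base_model` on pricing directly, here is a minimal standalone sketch (not part of this diff) comparing the per-token cost litellm computes for the model name Azure echoes back versus the configured base model. `litellm.cost_per_token` is the same helper this PR's tests use; the token counts are hypothetical:

```python
import litellm

# hypothetical token counts for illustration
prompt_tokens, completion_tokens = 1000, 500

# price using only the model name Azure returns in the response
prompt_price, completion_price = litellm.cost_per_token(
    model="gpt-4",
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
)
print(f"priced as gpt-4: {prompt_price + completion_price}")

# price using the base_model configured in model_info
prompt_price, completion_price = litellm.cost_per_token(
    model="azure/gpt-4-1106-preview",
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
)
print(f"priced as azure/gpt-4-1106-preview: {prompt_price + completion_price}")
```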
## Load API Keys
### Load API Keys from Environment
@@ -7,10 +7,8 @@ model_list:
      api_version: "2023-07-01-preview"
    model_info:
      mode: chat
-      input_cost_per_token: 0.00006
-      output_cost_per_token: 0.00003
      max_tokens: 4096
-      base_model: gpt-3.5-turbo
+      base_model: azure/gpt-4-1106-preview
  - model_name: gpt-4
    litellm_params:
      model: azure/chatgpt-v-2
@@ -74,11 +72,6 @@ litellm_settings:
    max_budget: 1.5000
    models: ["azure-gpt-3.5"]
    duration: None
  cache: True # set cache responses to True
  cache_params:
    type: "redis-semantic"
    similarity_threshold: 0.8
    redis_semantic_cache_embedding_model: azure-embedding-model
  upperbound_key_generate_params:
    max_budget: 100
    duration: "30d"
@@ -304,7 +304,10 @@ class Router:
            specific_deployment=kwargs.pop("specific_deployment", None),
        )
        kwargs.setdefault("metadata", {}).update(
-            {"deployment": deployment["litellm_params"]["model"]}
+            {
+                "deployment": deployment["litellm_params"]["model"],
+                "model_info": deployment.get("model_info", {}),
+            }
        )
        data = deployment["litellm_params"].copy()
        kwargs["model_info"] = deployment.get("model_info", {})
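
This same two-step pattern is repeated in each Router entrypoint below: `setdefault` guarantees a `metadata` dict exists on the kwargs, then `update` merges in the deployment name and its `model_info`. A standalone sketch (with a hypothetical `deployment` dict) of what that produces:

```python
# hypothetical deployment entry, shaped like a model_list item
deployment = {
    "litellm_params": {"model": "azure/chatgpt-v-2"},
    "model_info": {"base_model": "azure/gpt-4-1106-preview"},
}

kwargs = {}  # stands in for the request's keyword arguments

# setdefault returns the existing "metadata" dict or installs {} first,
# so update() never hits a KeyError
kwargs.setdefault("metadata", {}).update(
    {
        "deployment": deployment["litellm_params"]["model"],
        "model_info": deployment.get("model_info", {}),
    }
)

print(kwargs["metadata"]["model_info"]["base_model"])
# -> azure/gpt-4-1106-preview
```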
@@ -376,7 +379,10 @@ class Router:
            specific_deployment=kwargs.pop("specific_deployment", None),
        )
        kwargs.setdefault("metadata", {}).update(
-            {"deployment": deployment["litellm_params"]["model"]}
+            {
+                "deployment": deployment["litellm_params"]["model"],
+                "model_info": deployment.get("model_info", {}),
+            }
        )
        kwargs["model_info"] = deployment.get("model_info", {})
        data = deployment["litellm_params"].copy()
@@ -451,7 +457,10 @@ class Router:
            specific_deployment=kwargs.pop("specific_deployment", None),
        )
        kwargs.setdefault("metadata", {}).update(
-            {"deployment": deployment["litellm_params"]["model"]}
+            {
+                "deployment": deployment["litellm_params"]["model"],
+                "model_info": deployment.get("model_info", {}),
+            }
        )
        kwargs["model_info"] = deployment.get("model_info", {})
        data = deployment["litellm_params"].copy()
@@ -526,7 +535,10 @@ class Router:
            specific_deployment=kwargs.pop("specific_deployment", None),
        )
        kwargs.setdefault("metadata", {}).update(
-            {"deployment": deployment["litellm_params"]["model"]}
+            {
+                "deployment": deployment["litellm_params"]["model"],
+                "model_info": deployment.get("model_info", {}),
+            }
        )
        kwargs["model_info"] = deployment.get("model_info", {})
        data = deployment["litellm_params"].copy()
@@ -654,7 +666,10 @@ class Router:
            specific_deployment=kwargs.pop("specific_deployment", None),
        )
        kwargs.setdefault("metadata", {}).update(
-            {"deployment": deployment["litellm_params"]["model"]}
+            {
+                "deployment": deployment["litellm_params"]["model"],
+                "model_info": deployment.get("model_info", {}),
+            }
        )
        kwargs["model_info"] = deployment.get("model_info", {})
        data = deployment["litellm_params"].copy()
@@ -780,7 +795,10 @@ class Router:
            specific_deployment=kwargs.pop("specific_deployment", None),
        )
        kwargs.setdefault("metadata", {}).update(
-            {"deployment": deployment["litellm_params"]["model"]}
+            {
+                "deployment": deployment["litellm_params"]["model"],
+                "model_info": deployment.get("model_info", {}),
+            }
        )
        kwargs["model_info"] = deployment.get("model_info", {})
        data = deployment["litellm_params"].copy()
@@ -1407,7 +1425,6 @@ class Router:
                    max_retries = litellm.get_secret(max_retries_env_name)
                    litellm_params["max_retries"] = max_retries

                # proxy support
                import os
                import httpx
@@ -1743,7 +1743,7 @@ def test_azure_cloudflare_api():

def test_completion_anyscale_2():
    try:
-        # litellm.set_verbose= True
+        # litellm.set_verbose = True
        messages = [
            {"role": "system", "content": "You're a good bot"},
            {
@@ -256,6 +256,7 @@ class CompletionCustomHandler(
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("async_success")
+            print("in async success, kwargs: ", kwargs)
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
@@ -266,6 +267,38 @@ class CompletionCustomHandler(
            )
            ## KWARGS
            assert isinstance(kwargs["model"], str)
+
+            # checking we use base_model for azure cost calculation
+            base_model = litellm.utils._get_base_model_from_metadata(
+                model_call_details=kwargs
+            )
+
+            if (
+                kwargs["model"] == "chatgpt-v-2"
+                and base_model is not None
+                and kwargs["stream"] != True
+            ):
+                # when base_model is set for azure, we should use pricing for the base_model
+                # this checks response_cost == litellm.cost_per_token(model=base_model)
+                assert isinstance(kwargs["response_cost"], float)
+                response_cost = kwargs["response_cost"]
+                print(
+                    f"response_cost: {response_cost}, for model: {kwargs['model']} and base_model: {base_model}"
+                )
+                prompt_tokens = response_obj.usage.prompt_tokens
+                completion_tokens = response_obj.usage.completion_tokens
+                # ensure the pricing is based on the base_model here
+                prompt_price, completion_price = litellm.cost_per_token(
+                    model=base_model,
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                )
+                expected_price = prompt_price + completion_price
+                print(f"expected price: {expected_price}")
+                assert (
+                    response_cost == expected_price
+                ), f"response_cost: {response_cost} != expected_price: {expected_price}. For model: {kwargs['model']} and base_model: {base_model}. should have used base_model for price"
+
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
@@ -345,6 +378,7 @@ async def test_async_chat_azure():
    customHandler_streaming_azure_router = CompletionCustomHandler()
    customHandler_failure = CompletionCustomHandler()
    litellm.callbacks = [customHandler_completion_azure_router]
+    litellm.set_verbose = True
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",  # openai model name
@@ -354,6 +388,7 @@ async def test_async_chat_azure():
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
+            "model_info": {"base_model": "azure/gpt-4-1106-preview"},
            "tpm": 240000,
            "rpm": 1800,
        },
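
For reference, a minimal standalone sketch of wiring a Router with `model_info.base_model` like the test's model_list entry (assumes Azure credentials in the environment; not part of the diff):

```python
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
            # base_model drives cost calculation for this deployment
            "model_info": {"base_model": "azure/gpt-4-1106-preview"},
        }
    ]
)

response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
```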
@@ -1079,10 +1079,17 @@ class Logging:
                        call_type=self.call_type,
                    )
                else:
+                    # check if base_model set on azure
+                    base_model = _get_base_model_from_metadata(
+                        model_call_details=self.model_call_details
+                    )
+                    # base_model defaults to None if not set on model_info
                    self.model_call_details[
                        "response_cost"
                    ] = litellm.completion_cost(
-                        completion_response=result, call_type=self.call_type
+                        completion_response=result,
+                        call_type=self.call_type,
+                        model=base_model,
                    )
                    verbose_logger.debug(
                        f"Model={self.model}; cost={self.model_call_details['response_cost']}"
@@ -1158,10 +1165,16 @@ class Logging:
                if self.model_call_details.get("cache_hit", False) == True:
                    self.model_call_details["response_cost"] = 0.0
                else:
+                    # check if base_model set on azure
+                    base_model = _get_base_model_from_metadata(
+                        model_call_details=self.model_call_details
+                    )
+                    # base_model defaults to None if not set on model_info
                    self.model_call_details[
                        "response_cost"
                    ] = litellm.completion_cost(
                        completion_response=complete_streaming_response,
+                        model=base_model,
                    )
                    verbose_logger.debug(
                        f"Model={self.model}; cost={self.model_call_details['response_cost']}"
@@ -1479,8 +1492,14 @@ class Logging:
                if self.model_call_details.get("cache_hit", False) == True:
                    self.model_call_details["response_cost"] = 0.0
                else:
+                    # check if base_model set on azure
+                    base_model = _get_base_model_from_metadata(
+                        model_call_details=self.model_call_details
+                    )
+                    # base_model defaults to None if not set on model_info
                    self.model_call_details["response_cost"] = litellm.completion_cost(
                        completion_response=complete_streaming_response,
+                        model=base_model,
                    )
                    verbose_logger.debug(
                        f"Model={self.model}; cost={self.model_call_details['response_cost']}"
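
All three Logging paths above apply the same fix; a condensed sketch of the pattern (names follow the diff, and the None fallback behavior is inferred from the "defaults to None" comment, so treat it as an assumption):

```python
import litellm
from litellm.utils import _get_base_model_from_metadata


def compute_response_cost(result, model_call_details):
    # prefer the Azure base_model when one was configured; completion_cost
    # falls back to the model on the response object when model is None
    base_model = _get_base_model_from_metadata(
        model_call_details=model_call_details
    )
    return litellm.completion_cost(
        completion_response=result,
        model=base_model,
    )
```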
@@ -9231,3 +9250,21 @@ def get_logging_id(start_time, response_obj):
        return response_id
    except:
        return None
+
+
+def _get_base_model_from_metadata(model_call_details=None):
+    if model_call_details is None:
+        return None
+    litellm_params = model_call_details.get("litellm_params", {})
+
+    if litellm_params is not None:
+        metadata = litellm_params.get("metadata", {})
+
+        if metadata is not None:
+            model_info = metadata.get("model_info", {})
+
+            if model_info is not None:
+                base_model = model_info.get("base_model", None)
+                if base_model is not None:
+                    return base_model
+    return None
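
A quick usage sketch for the new helper: it walks `litellm_params` → `metadata` → `model_info` → `base_model` and returns `None` whenever a level is missing (the sample dicts below are hypothetical):

```python
model_call_details = {
    "litellm_params": {
        "metadata": {
            "model_info": {"base_model": "azure/gpt-4-1106-preview"},
        },
    },
}

# base_model found at the bottom of the nested structure
assert _get_base_model_from_metadata(model_call_details) == "azure/gpt-4-1106-preview"

# any missing level short-circuits to None
assert _get_base_model_from_metadata({}) is None
assert _get_base_model_from_metadata(None) is None
```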