Merge pull request #1874 from BerriAI/litellm_azure_base_model_pricing

[FEAT] Azure Pricing - based on base_model in model_info

Commit 98b0ace2e9 (6 changed files with 121 additions and 17 deletions)
**Docs (proxy config guide).** A new section documents the feature:

````diff
@@ -246,6 +246,28 @@ model_list:
 $ litellm --config /path/to/config.yaml
 ```
 
+## Set Azure `base_model` for cost tracking
+
+**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking.
+
+**Solution** ✅: Set `base_model` in your config so litellm uses the correct model for calculating Azure cost.
+
+Example config with `base_model`:
+```yaml
+model_list:
+  - model_name: azure-gpt-3.5
+    litellm_params:
+      model: azure/chatgpt-v-2
+      api_base: os.environ/AZURE_API_BASE
+      api_key: os.environ/AZURE_API_KEY
+      api_version: "2023-07-01-preview"
+    model_info:
+      base_model: azure/gpt-4-1106-preview
+```
+
+You can view your cost once you set up [Virtual keys](https://docs.litellm.ai/docs/proxy/virtual_keys) or [custom_callbacks](https://docs.litellm.ai/docs/proxy/logging).
 
 ## Load API Keys
 
 ### Load API Keys from Environment
````
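To see why the mislabel matters, here is a minimal sketch comparing the two prices with `litellm.cost_per_token` (the same helper the test below uses). The token counts are made up, and it assumes both model names are present in litellm's pricing map:

```python
import litellm

# Hypothetical usage for one request against an azure/gpt-4-1106-preview
# deployment that reports itself as plain "gpt-4".
prompt_tokens, completion_tokens = 1_000, 500

# Priced from the model name Azure returns (what happens without base_model)
wrong = sum(litellm.cost_per_token(
    model="gpt-4",
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
))

# Priced from the configured base_model (what this PR enables)
right = sum(litellm.cost_per_token(
    model="azure/gpt-4-1106-preview",
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
))

# At the prices in effect when this PR landed, the mislabeled cost is
# roughly 2-3x the real one.
print(f"priced as gpt-4: ${wrong:.4f} vs. priced as base_model: ${right:.4f}")
```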
**Proxy config.** The example proxy config drops its hand-set per-token prices (note the removed `input_cost_per_token` line carried a typo, `0.0.00006`) and points `base_model` at the preview model instead:

```diff
@@ -7,10 +7,8 @@ model_list:
       api_version: "2023-07-01-preview"
     model_info:
       mode: chat
-      input_cost_per_token: 0.0.00006
-      output_cost_per_token: 0.00003
       max_tokens: 4096
-      base_model: gpt-3.5-turbo
+      base_model: azure/gpt-4-1106-preview
   - model_name: gpt-4
     litellm_params:
       model: azure/chatgpt-v-2
```
The same config also removes the redis-semantic cache settings (indentation reconstructed; the removed block sat under `litellm_settings`):

```diff
@@ -74,11 +72,6 @@ litellm_settings:
     max_budget: 1.5000
     models: ["azure-gpt-3.5"]
     duration: None
-  cache: True # set cache responses to True
-  cache_params:
-    type: "redis-semantic"
-    similarity_threshold: 0.8
-    redis_semantic_cache_embedding_model: azure-embedding-model
   upperbound_key_generate_params:
     max_budget: 100
     duration: "30d"
```
**`Router`.** Each of the Router's call paths now copies the deployment's `model_info` into the request metadata alongside the deployment name, so downstream logging can see the configured `base_model`. The first hunk:

```diff
@@ -304,7 +304,10 @@ class Router:
             specific_deployment=kwargs.pop("specific_deployment", None),
         )
         kwargs.setdefault("metadata", {}).update(
-            {"deployment": deployment["litellm_params"]["model"]}
+            {
+                "deployment": deployment["litellm_params"]["model"],
+                "model_info": deployment.get("model_info", {}),
+            }
         )
         data = deployment["litellm_params"].copy()
         kwargs["model_info"] = deployment.get("model_info", {})
```

The identical metadata edit repeats in five more `class Router` methods, in the hunks at `@@ -376,7 +379,10 @@`, `@@ -451,7 +457,10 @@`, `@@ -526,7 +535,10 @@`, `@@ -654,7 +666,10 @@`, and `@@ -780,7 +795,10 @@`.
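The merge pattern is worth noting: `kwargs.setdefault("metadata", {})` returns the caller's existing metadata dict if one was passed, so caller keys survive and the deployment keys are layered on top. A standalone sketch (the `deployment` shape mirrors the Router's `model_list` entries):

```python
def attach_model_info(kwargs: dict, deployment: dict) -> dict:
    """Merge deployment details into request metadata, as the Router does."""
    kwargs.setdefault("metadata", {}).update(
        {
            "deployment": deployment["litellm_params"]["model"],
            "model_info": deployment.get("model_info", {}),
        }
    )
    # model_info is also exposed at the top level of kwargs
    kwargs["model_info"] = deployment.get("model_info", {})
    return kwargs


deployment = {
    "litellm_params": {"model": "azure/chatgpt-v-2"},
    "model_info": {"base_model": "azure/gpt-4-1106-preview"},
}
kwargs = {"metadata": {"user": "abc"}}  # caller-supplied metadata survives
attach_model_info(kwargs, deployment)
print(kwargs["metadata"])
# {'user': 'abc', 'deployment': 'azure/chatgpt-v-2',
#  'model_info': {'base_model': 'azure/gpt-4-1106-preview'}}
```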
A small cleanup elsewhere in `class Router` (the 7-to-6 line count suggests a stray blank line was removed):

```diff
@@ -1407,7 +1425,6 @@ class Router:
         max_retries = litellm.get_secret(max_retries_env_name)
         litellm_params["max_retries"] = max_retries
 
-
         # proxy support
         import os
         import httpx
```
**Tests.** A comment in `test_completion_anyscale_2` gets a spacing fix:

```diff
@@ -1743,7 +1743,7 @@ def test_azure_cloudflare_api():
 
 def test_completion_anyscale_2():
     try:
-        # litellm.set_verbose= True
+        # litellm.set_verbose = True
         messages = [
             {"role": "system", "content": "You're a good bot"},
             {
```
The custom-callback test handler gains a debug print:

```diff
@@ -256,6 +256,7 @@ class CompletionCustomHandler(
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             self.states.append("async_success")
+            print("in async success, kwargs: ", kwargs)
             ## START TIME
             assert isinstance(start_time, datetime)
             ## END TIME
```
More importantly, the async success handler now asserts that a response from an Azure deployment with `base_model` set is priced from the base model:

```diff
@@ -266,6 +267,38 @@ class CompletionCustomHandler(
             )
             ## KWARGS
             assert isinstance(kwargs["model"], str)
+
+            # checking we use base_model for azure cost calculation
+            base_model = litellm.utils._get_base_model_from_metadata(
+                model_call_details=kwargs
+            )
+
+            if (
+                kwargs["model"] == "chatgpt-v-2"
+                and base_model is not None
+                and kwargs["stream"] != True
+            ):
+                # when base_model is set for azure, we should use pricing for the base_model
+                # this checks response_cost == litellm.cost_per_token(model=base_model)
+                assert isinstance(kwargs["response_cost"], float)
+                response_cost = kwargs["response_cost"]
+                print(
+                    f"response_cost: {response_cost}, for model: {kwargs['model']} and base_model: {base_model}"
+                )
+                prompt_tokens = response_obj.usage.prompt_tokens
+                completion_tokens = response_obj.usage.completion_tokens
+                # ensure the pricing is based on the base_model here
+                prompt_price, completion_price = litellm.cost_per_token(
+                    model=base_model,
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                )
+                expected_price = prompt_price + completion_price
+                print(f"expected price: {expected_price}")
+                assert (
+                    response_cost == expected_price
+                ), f"response_cost: {response_cost} != expected_price: {expected_price}. For model: {kwargs['model']} and base_model: {base_model}. should have used base_model for price"
+
             assert isinstance(kwargs["messages"], list)
             assert isinstance(kwargs["optional_params"], dict)
             assert isinstance(kwargs["litellm_params"], dict)
```
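If you want to consume the cost this computes in your own app, here is a minimal callback sketch. The class name and print format are illustrative; `CustomLogger` is the base class the test handler above extends, and `response_cost` arrives in `kwargs` exactly as the test reads it:

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger


class CostTracker(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # response_cost is computed from base_model when one is configured
        cost = kwargs.get("response_cost", 0.0)
        print(f"model={kwargs['model']} cost=${cost:.6f}")


litellm.callbacks = [CostTracker()]
```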
`test_async_chat_azure` turns on verbose logging and adds `base_model` to its Azure deployment fixture:

```diff
@@ -345,6 +378,7 @@ async def test_async_chat_azure():
     customHandler_streaming_azure_router = CompletionCustomHandler()
     customHandler_failure = CompletionCustomHandler()
     litellm.callbacks = [customHandler_completion_azure_router]
+    litellm.set_verbose = True
     model_list = [
         {
             "model_name": "gpt-3.5-turbo",  # openai model name
@@ -354,6 +388,7 @@ async def test_async_chat_azure():
                 "api_version": os.getenv("AZURE_API_VERSION"),
                 "api_base": os.getenv("AZURE_API_BASE"),
             },
+            "model_info": {"base_model": "azure/gpt-4-1106-preview"},
             "tpm": 240000,
             "rpm": 1800,
         },
```
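Put together, the fixture amounts to the following shape. A sketch with keys trimmed to the ones the diff shows; credentials come from env vars, and the `acompletion` call is indicative:

```python
import os
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # the alias callers request
        "litellm_params": {
            "model": "azure/chatgpt-v-2",  # the actual Azure deployment
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
        # priced as gpt-4-1106-preview regardless of what Azure reports
        "model_info": {"base_model": "azure/gpt-4-1106-preview"},
        "tpm": 240000,
        "rpm": 1800,
    }
]
router = Router(model_list=model_list)
# response = await router.acompletion(
#     model="gpt-3.5-turbo",
#     messages=[{"role": "user", "content": "hi"}],
# )
```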
**`Logging` (cost calculation).** All three branches that compute `response_cost` (the non-streaming path and the two that price a `complete_streaming_response`) now look up `base_model` from the call metadata and pass it through to `litellm.completion_cost`; it stays `None` when unset, preserving the old behavior:

```diff
@@ -1079,10 +1079,17 @@ class Logging:
                     call_type=self.call_type,
                 )
             else:
+                # check if base_model set on azure
+                base_model = _get_base_model_from_metadata(
+                    model_call_details=self.model_call_details
+                )
+                # base_model defaults to None if not set on model_info
                 self.model_call_details[
                     "response_cost"
                 ] = litellm.completion_cost(
-                    completion_response=result, call_type=self.call_type
+                    completion_response=result,
+                    call_type=self.call_type,
+                    model=base_model,
                 )
                 verbose_logger.debug(
                     f"Model={self.model}; cost={self.model_call_details['response_cost']}"
@@ -1158,10 +1165,16 @@ class Logging:
             if self.model_call_details.get("cache_hit", False) == True:
                 self.model_call_details["response_cost"] = 0.0
             else:
+                # check if base_model set on azure
+                base_model = _get_base_model_from_metadata(
+                    model_call_details=self.model_call_details
+                )
+                # base_model defaults to None if not set on model_info
                 self.model_call_details[
                     "response_cost"
                 ] = litellm.completion_cost(
                     completion_response=complete_streaming_response,
+                    model=base_model,
                 )
                 verbose_logger.debug(
                     f"Model={self.model}; cost={self.model_call_details['response_cost']}"
@@ -1479,8 +1492,14 @@ class Logging:
             if self.model_call_details.get("cache_hit", False) == True:
                 self.model_call_details["response_cost"] = 0.0
             else:
+                # check if base_model set on azure
+                base_model = _get_base_model_from_metadata(
+                    model_call_details=self.model_call_details
+                )
+                # base_model defaults to None if not set on model_info
                 self.model_call_details["response_cost"] = litellm.completion_cost(
                     completion_response=complete_streaming_response,
+                    model=base_model,
                 )
                 verbose_logger.debug(
                     f"Model={self.model}; cost={self.model_call_details['response_cost']}"
```
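`litellm.completion_cost` accepts an optional `model=` override, and that override is the whole mechanism here. A sketch of the difference, assuming Azure credentials are configured for the deployment names taken from this diff:

```python
import litellm

response = litellm.completion(
    model="azure/chatgpt-v-2",
    messages=[{"role": "user", "content": "hi"}],
)

# Default: priced from the model name carried on the response object,
# which Azure reports as plain "gpt-4"
default_cost = litellm.completion_cost(completion_response=response)

# Override: priced as the configured base model instead
override_cost = litellm.completion_cost(
    completion_response=response,
    model="azure/gpt-4-1106-preview",
)
print(default_cost, override_cost)
```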
**New helper.** `_get_base_model_from_metadata` walks the metadata the Router attached and pulls out `base_model`, if any:

```diff
@@ -9231,3 +9250,21 @@ def get_logging_id(start_time, response_obj):
         return response_id
     except:
         return None
+
+
+def _get_base_model_from_metadata(model_call_details=None):
+    if model_call_details is None:
+        return None
+    litellm_params = model_call_details.get("litellm_params", {})
+
+    if litellm_params is not None:
+        metadata = litellm_params.get("metadata", {})
+
+        if metadata is not None:
+            model_info = metadata.get("model_info", {})
+
+            if model_info is not None:
+                base_model = model_info.get("base_model", None)
+                if base_model is not None:
+                    return base_model
+    return None
```
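The helper just follows `litellm_params -> metadata -> model_info -> base_model`, returning `None` at the first missing link. A quick check of both paths against the definition above:

```python
hit = {
    "litellm_params": {
        "metadata": {
            "model_info": {"base_model": "azure/gpt-4-1106-preview"},
        },
    },
}
assert _get_base_model_from_metadata(model_call_details=hit) == "azure/gpt-4-1106-preview"

# any missing link in the chain falls through to None
assert _get_base_model_from_metadata(model_call_details={"litellm_params": {}}) is None
assert _get_base_model_from_metadata(model_call_details=None) is None
```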