feat(main.py): allow updating model cost via completion()

This commit is contained in:
Krrish Dholakia 2023-11-29 20:14:31 -08:00
parent 4c1ef4e270
commit 0d200cd8dc
3 changed files with 27 additions and 3 deletions

View file

@ -331,6 +331,9 @@ def completion(
num_retries = kwargs.get("num_retries", None) ## deprecated
max_retries = kwargs.get("max_retries", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None)
### CUSTOM PROMPT TEMPLATE ###
initial_prompt_value = kwargs.get("initial_prompt_value", None)
roles = kwargs.get("roles", None)
@ -341,7 +344,7 @@ def completion(
client = kwargs.get("client", None)
######## end of unpacking kwargs ###########
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
@ -375,6 +378,17 @@ def completion(
model=deployment_id
custom_llm_provider="azure"
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None:
litellm.register_model({
model: {
"input_cost_per_token": input_cost_per_token,
"output_cost_per_token": output_cost_per_token,
"litellm_provider": custom_llm_provider
}
})
### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ###
custom_prompt_dict = {} # type: ignore
if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
custom_prompt_dict = {model: {}}

View file

@ -33,4 +33,15 @@ def test_update_model_cost_map_url():
except Exception as e:
pytest.fail(f"An error occurred: {e}")
test_update_model_cost_map_url()
# test_update_model_cost_map_url()
def test_update_model_cost_via_completion():
try:
response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}], input_cost_per_token=0.3, output_cost_per_token=0.4)
print(f"litellm.model_cost for gpt-3.5-turbo: {litellm.model_cost['gpt-3.5-turbo']}")
assert litellm.model_cost["gpt-3.5-turbo"]["input_cost_per_token"] == 0.3
assert litellm.model_cost["gpt-3.5-turbo"]["output_cost_per_token"] == 0.4
except Exception as e:
pytest.fail(f"An error occurred: {e}")
test_update_model_cost_via_completion()

View file

@ -1794,7 +1794,6 @@ def register_model(model_cost: Union[str, dict]):
for key, value in loaded_model_cost.items():
## override / add new keys to the existing model cost dictionary
litellm.model_cost[key] = loaded_model_cost[key]
# add new model names to provider lists
if value.get('litellm_provider') == 'openai':
if key not in litellm.open_ai_chat_completion_models: