forked from phoenix/litellm-mirror
feat(main.py): allow updating model cost via completion()
This commit is contained in:
parent
4c1ef4e270
commit
0d200cd8dc
3 changed files with 27 additions and 3 deletions
|
@@ -331,6 +331,9 @@ def completion(
|
||||||
num_retries = kwargs.get("num_retries", None) ## deprecated
|
num_retries = kwargs.get("num_retries", None) ## deprecated
|
||||||
max_retries = kwargs.get("max_retries", None)
|
max_retries = kwargs.get("max_retries", None)
|
||||||
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
|
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
|
||||||
|
### CUSTOM MODEL COST ###
|
||||||
|
input_cost_per_token = kwargs.get("input_cost_per_token", None)
|
||||||
|
output_cost_per_token = kwargs.get("output_cost_per_token", None)
|
||||||
### CUSTOM PROMPT TEMPLATE ###
|
### CUSTOM PROMPT TEMPLATE ###
|
||||||
initial_prompt_value = kwargs.get("initial_prompt_value", None)
|
initial_prompt_value = kwargs.get("initial_prompt_value", None)
|
||||||
roles = kwargs.get("roles", None)
|
roles = kwargs.get("roles", None)
|
||||||
|
@@ -341,7 +344,7 @@ def completion(
|
||||||
client = kwargs.get("client", None)
|
client = kwargs.get("client", None)
|
||||||
######## end of unpacking kwargs ###########
|
######## end of unpacking kwargs ###########
|
||||||
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
|
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
|
||||||
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm"]
|
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token"]
|
||||||
default_params = openai_params + litellm_params
|
default_params = openai_params + litellm_params
|
||||||
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
|
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
|
||||||
if mock_response:
|
if mock_response:
|
||||||
|
@@ -375,6 +378,17 @@ def completion(
|
||||||
model=deployment_id
|
model=deployment_id
|
||||||
custom_llm_provider="azure"
|
custom_llm_provider="azure"
|
||||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
|
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
|
||||||
|
|
||||||
|
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
|
||||||
|
if input_cost_per_token is not None and output_cost_per_token is not None:
|
||||||
|
litellm.register_model({
|
||||||
|
model: {
|
||||||
|
"input_cost_per_token": input_cost_per_token,
|
||||||
|
"output_cost_per_token": output_cost_per_token,
|
||||||
|
"litellm_provider": custom_llm_provider
|
||||||
|
}
|
||||||
|
})
|
||||||
|
### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ###
|
||||||
custom_prompt_dict = {} # type: ignore
|
custom_prompt_dict = {} # type: ignore
|
||||||
if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
|
if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
|
||||||
custom_prompt_dict = {model: {}}
|
custom_prompt_dict = {model: {}}
|
||||||
|
|
|
@@ -33,4 +33,15 @@ def test_update_model_cost_map_url():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"An error occurred: {e}")
|
pytest.fail(f"An error occurred: {e}")
|
||||||
|
|
||||||
test_update_model_cost_map_url()
|
# test_update_model_cost_map_url()
|
||||||
|
|
||||||
|
def test_update_model_cost_via_completion():
    """Verify that passing input_cost_per_token / output_cost_per_token to
    completion() registers that custom pricing in litellm.model_cost.

    Raises:
        pytest.Failed: if the completion() call itself errors out.
    """
    try:
        # Only the network-bound completion() call is guarded here. The
        # original wrapped the assertions in this try as well, which turned
        # assertion failures into a generic pytest.fail("An error occurred:")
        # and discarded pytest's assertion introspection output.
        response = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            input_cost_per_token=0.3,
            output_cost_per_token=0.4,
        )
    except Exception as e:
        pytest.fail(f"An error occurred: {e}")
    print(
        f"litellm.model_cost for gpt-3.5-turbo: {litellm.model_cost['gpt-3.5-turbo']}"
    )
    # completion() should have registered the custom per-token pricing.
    assert litellm.model_cost["gpt-3.5-turbo"]["input_cost_per_token"] == 0.3
    assert litellm.model_cost["gpt-3.5-turbo"]["output_cost_per_token"] == 0.4


test_update_model_cost_via_completion()
|
|
@@ -1794,7 +1794,6 @@ def register_model(model_cost: Union[str, dict]):
|
||||||
for key, value in loaded_model_cost.items():
|
for key, value in loaded_model_cost.items():
|
||||||
## override / add new keys to the existing model cost dictionary
|
## override / add new keys to the existing model cost dictionary
|
||||||
litellm.model_cost[key] = loaded_model_cost[key]
|
litellm.model_cost[key] = loaded_model_cost[key]
|
||||||
|
|
||||||
# add new model names to provider lists
|
# add new model names to provider lists
|
||||||
if value.get('litellm_provider') == 'openai':
|
if value.get('litellm_provider') == 'openai':
|
||||||
if key not in litellm.open_ai_chat_completion_models:
|
if key not in litellm.open_ai_chat_completion_models:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue