forked from phoenix/litellm-mirror

feat(main.py): allow updating model cost via completion()

parent 4c1ef4e270 / commit 0d200cd8dc
3 changed files with 27 additions and 3 deletions
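Per the commit message, the first three hunks below change completion() in main.py; the remaining hunks add a test for the new kwargs and touch register_model(). The new kwargs let a caller attach custom per-token pricing to a model at call time. A minimal usage sketch based on the test added in this commit (the price values are illustrative, not real prices):

    import litellm

    # Pass custom pricing straight into completion(); per this commit it gets
    # registered in litellm.model_cost under the model name.
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        input_cost_per_token=0.00001,   # illustrative value
        output_cost_per_token=0.00002,  # illustrative value
    )
    print(litellm.model_cost["gpt-3.5-turbo"])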
@@ -331,6 +331,9 @@ def completion(
     num_retries = kwargs.get("num_retries", None) ## deprecated
     max_retries = kwargs.get("max_retries", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
+    ### CUSTOM MODEL COST ###
+    input_cost_per_token = kwargs.get("input_cost_per_token", None)
+    output_cost_per_token = kwargs.get("output_cost_per_token", None)
     ### CUSTOM PROMPT TEMPLATE ###
     initial_prompt_value = kwargs.get("initial_prompt_value", None)
     roles = kwargs.get("roles", None)
@@ -341,7 +344,7 @@ def completion(
     client = kwargs.get("client", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
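Adding the two new keys to litellm_params matters because of the filtering step just above: any kwarg not found in default_params is treated as a model-specific parameter and forwarded straight to the provider. A rough, self-contained sketch of that filtering, with abridged stand-in lists and made-up kwargs:

    # Abridged stand-ins for the full lists in the diff above (illustration only)
    openai_params = ["temperature", "max_tokens"]
    litellm_params = ["metadata", "input_cost_per_token", "output_cost_per_token"]
    kwargs = {"input_cost_per_token": 1e-05, "top_k": 40, "metadata": {"run": 1}}

    default_params = openai_params + litellm_params
    non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
    print(non_default_params)  # {'top_k': 40} -- the cost kwargs no longer leak to the provider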
@@ -375,6 +378,17 @@ def completion(
         model=deployment_id
         custom_llm_provider="azure"
     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
+
+    ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
+    if input_cost_per_token is not None and output_cost_per_token is not None:
+        litellm.register_model({
+            model: {
+                "input_cost_per_token": input_cost_per_token,
+                "output_cost_per_token": output_cost_per_token,
+                "litellm_provider": custom_llm_provider
+            }
+        })
+    ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ###
     custom_prompt_dict = {} # type: ignore
     if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
         custom_prompt_dict = {model: {}}
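After this block runs, the custom prices are stored in litellm.model_cost under the resolved model name, and later cost accounting reads them from there. The essential downstream arithmetic is just per-token price times token count (token counts and prices below are made up; litellm also ships a completion_cost() helper that performs this lookup):

    # Shape of an entry as registered by the block above (values illustrative)
    pricing = {"input_cost_per_token": 1e-05, "output_cost_per_token": 2e-05}

    prompt_tokens, completion_tokens = 120, 45   # e.g. from response.usage
    cost = (prompt_tokens * pricing["input_cost_per_token"]
            + completion_tokens * pricing["output_cost_per_token"])
    print(f"estimated cost: ${cost:.6f}")        # $0.002100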
@@ -33,4 +33,15 @@ def test_update_model_cost_map_url():
     except Exception as e:
         pytest.fail(f"An error occurred: {e}")

-test_update_model_cost_map_url()
+# test_update_model_cost_map_url()
+
+def test_update_model_cost_via_completion():
+    try:
+        response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}], input_cost_per_token=0.3, output_cost_per_token=0.4)
+        print(f"litellm.model_cost for gpt-3.5-turbo: {litellm.model_cost['gpt-3.5-turbo']}")
+        assert litellm.model_cost["gpt-3.5-turbo"]["input_cost_per_token"] == 0.3
+        assert litellm.model_cost["gpt-3.5-turbo"]["output_cost_per_token"] == 0.4
+    except Exception as e:
+        pytest.fail(f"An error occurred: {e}")
+
+test_update_model_cost_via_completion()
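The 0.3 / 0.4 values in the new test are test-only magnitudes chosen so the assertions are unambiguous; real per-token prices are far smaller. Purely illustrative arithmetic with the test's registered prices:

    # Illustrative only: with the test's prices, a response with 10 prompt tokens
    # and 5 completion tokens would be priced at
    expected = 10 * 0.3 + 5 * 0.4   # = 5.0 -- obviously a test-only price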
@@ -1794,7 +1794,6 @@ def register_model(model_cost: Union[str, dict]):
     for key, value in loaded_model_cost.items():
         ## override / add new keys to the existing model cost dictionary
         litellm.model_cost[key] = loaded_model_cost[key]
-
         # add new model names to provider lists
         if value.get('litellm_provider') == 'openai':
             if key not in litellm.open_ai_chat_completion_models:
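For reference, register_model() can also be called directly with a dict (its signature accepts a string as well, per the Union[str, dict] annotation). A minimal sketch with a hypothetical model name:

    import litellm

    litellm.register_model({
        "my-finetuned-gpt": {                # hypothetical model name
            "input_cost_per_token": 0.000002,
            "output_cost_per_token": 0.000004,
            "litellm_provider": "openai",
        }
    })
    # Per the loop above: existing keys are overridden, new keys are added, and
    # the "add new model names to provider lists" branch also tracks openai models
    # in litellm.open_ai_chat_completion_models.
    assert "my-finetuned-gpt" in litellm.model_cost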