Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(main.py): support custom pricing for embedding calls

Parent: 39a1b4c3b5
Commit: 2ce4258cc0

2 changed files with 34 additions and 1 deletion
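
For context on what the change enables: `embedding()` now reads the same custom-pricing kwargs that `completion()` already strips out (see the first hunk below), and registers them before dispatching the call. A minimal usage sketch — the prices are illustrative, and the call assumes `OPENAI_API_KEY` is set:

```python
import litellm

# Illustrative per-token prices for cost tracking; any values work.
response = litellm.embedding(
    model="text-embedding-ada-002",
    input=["hello world"],
    input_cost_per_token=0.0000001,
    output_cost_per_token=0.0,  # embeddings return no completion tokens
)

# Time-based pricing is also supported; per the diff, only the input-side
# cost is required and the output side defaults to 0.0:
# litellm.embedding(..., input_cost_per_second=0.001)
```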
litellm/main.py

@@ -536,6 +536,8 @@ def completion(
         "tpm",
         "input_cost_per_token",
         "output_cost_per_token",
+        "input_cost_per_second",
+        "output_cost_per_second",
         "hf_model_name",
         "model_info",
         "proxy_server_request",
@@ -2262,6 +2264,11 @@ def embedding(
     encoding_format = kwargs.get("encoding_format", None)
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
+    ### CUSTOM MODEL COST ###
+    input_cost_per_token = kwargs.get("input_cost_per_token", None)
+    output_cost_per_token = kwargs.get("output_cost_per_token", None)
+    input_cost_per_second = kwargs.get("input_cost_per_second", None)
+    output_cost_per_second = kwargs.get("output_cost_per_second", None)
     openai_params = [
         "user",
         "request_timeout",
@@ -2310,6 +2317,8 @@ def embedding(
         "tpm",
         "input_cost_per_token",
         "output_cost_per_token",
+        "input_cost_per_second",
+        "output_cost_per_second",
         "hf_model_name",
         "proxy_server_request",
         "model_info",
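
The two hunks above extend the same allow-list in `completion()` and `embedding()`: kwargs named here are treated as litellm-internal and split away from what gets forwarded to the provider, so the new pricing keys must appear in it or they would be passed to the provider API as unknown arguments. A simplified sketch of that filtering idea (the names below are mine, not the literal litellm code):

```python
# Keys litellm consumes itself rather than forwarding to the provider.
LITELLM_PARAMS = {
    "tpm",
    "input_cost_per_token",
    "output_cost_per_token",
    "input_cost_per_second",   # added by this commit
    "output_cost_per_second",  # added by this commit
    "hf_model_name",
    "model_info",
    "proxy_server_request",
}

def split_params(kwargs: dict) -> tuple[dict, dict]:
    """Separate litellm-internal kwargs from provider-bound kwargs."""
    internal = {k: v for k, v in kwargs.items() if k in LITELLM_PARAMS}
    provider = {k: v for k, v in kwargs.items() if k not in LITELLM_PARAMS}
    return internal, provider
```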
@@ -2335,6 +2344,28 @@ def embedding(
         custom_llm_provider=custom_llm_provider,
         **non_default_params,
     )
+    ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
+    if input_cost_per_token is not None and output_cost_per_token is not None:
+        litellm.register_model(
+            {
+                model: {
+                    "input_cost_per_token": input_cost_per_token,
+                    "output_cost_per_token": output_cost_per_token,
+                    "litellm_provider": custom_llm_provider,
+                }
+            }
+        )
+    if input_cost_per_second is not None:  # time based pricing just needs cost in place
+        output_cost_per_second = output_cost_per_second or 0.0
+        litellm.register_model(
+            {
+                model: {
+                    "input_cost_per_second": input_cost_per_second,
+                    "output_cost_per_second": output_cost_per_second,
+                    "litellm_provider": custom_llm_provider,
+                }
+            }
+        )
     try:
         response = None
         logging = litellm_logging_obj
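
The block above is the heart of the fix: caller-supplied prices are written into litellm's model-cost registry under the model's name, so downstream cost tracking treats the custom model like a known one. The standalone equivalent, using the public `litellm.register_model` API the diff itself calls (model name and prices are made up):

```python
import litellm

# Equivalent to what embedding() now does internally when it sees
# custom pricing kwargs.
litellm.register_model(
    {
        "my-custom-embedding-model": {
            "input_cost_per_token": 0.000003,
            "output_cost_per_token": 0.0,
            "litellm_provider": "openai",
        }
    }
)

# The entry lands in litellm.model_cost, where cost-tracking code looks it up.
print(litellm.model_cost["my-custom-embedding-model"])
```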
litellm/utils.py

@@ -2371,7 +2371,9 @@ def client(original_function):
             result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                 "id", None
             )
-            if isinstance(result, ModelResponse):
+            if isinstance(result, ModelResponse) or isinstance(
+                result, EmbeddingResponse
+            ):
                 result._response_ms = (
                     end_time - start_time
                 ).total_seconds() * 1000  # return response latency in ms like openai
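
This last hunk makes the `client()` wrapper stamp latency on `EmbeddingResponse` just as it already did for `ModelResponse`, which time-based pricing relies on. An illustrative check — `_response_ms` is an internal field, so this is for inspection only:

```python
import litellm

# Requires OPENAI_API_KEY; the model name is just an example.
response = litellm.embedding(model="text-embedding-ada-002", input=["hi"])

# Set by the client() wrapper patched above; mirrors the latency field
# chat responses already carried.
print(getattr(response, "_response_ms", None))  # latency in milliseconds
```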