feat(databricks.py): support vertex mistral cost tracking
commit 6d5aedc48d (parent 6f9c29d39b)
4 changed files with 56 additions and 87 deletions
@@ -509,7 +509,7 @@ def completion_cost(
                 ):
                     model = completion_response._hidden_params.get("model", model)
                 custom_llm_provider = completion_response._hidden_params.get(
-                    "custom_llm_provider", ""
+                    "custom_llm_provider", custom_llm_provider or ""
                 )
                 region_name = completion_response._hidden_params.get(
                     "region_name", region_name
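The effect of this first change is that a caller-supplied `custom_llm_provider` (for example `vertex_ai`) now survives when the response's hidden params carry no provider of their own, instead of being silently reset to an empty string. A minimal sketch of the before/after behaviour, using illustrative stand-ins for the hidden params dict and the argument passed into `completion_cost`:

```python
# illustrative stand-ins, not taken from the diff
hidden_params = {}                 # response carried no provider hint
custom_llm_provider = "vertex_ai"  # value the caller passed in

old_value = hidden_params.get("custom_llm_provider", "")
# -> "" : the caller's provider was dropped

new_value = hidden_params.get("custom_llm_provider", custom_llm_provider or "")
# -> "vertex_ai" : the caller's provider is kept as the fallback
```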
@@ -732,7 +732,7 @@ def response_cost_calculator(
             )
         return response_cost
     except litellm.NotFoundError as e:
-        print_verbose(
+        verbose_logger.debug(  # debug since it can be spammy in logs, for calls
             f"Model={model} for LLM Provider={custom_llm_provider} not found in completion cost map."
         )
         return None
@@ -259,87 +259,6 @@ class DatabricksChatCompletion(BaseLLM):
             api_base = "{}/embeddings".format(api_base)
         return api_base, headers

-    def process_response(
-        self,
-        model: str,
-        response: Union[requests.Response, httpx.Response],
-        model_response: ModelResponse,
-        stream: bool,
-        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
-        optional_params: dict,
-        api_key: str,
-        data: Union[dict, str],
-        messages: List,
-        print_verbose,
-        encoding,
-    ) -> ModelResponse:
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        try:
-            completion_response = response.json()
-        except:
-            raise DatabricksError(
-                message=response.text, status_code=response.status_code
-            )
-        if "error" in completion_response:
-            raise DatabricksError(
-                message=str(completion_response["error"]),
-                status_code=response.status_code,
-            )
-        else:
-            text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
-                if content["type"] == "text":
-                    text_content += content["text"]
-                ## TOOL CALLING
-                elif content["type"] == "tool_use":
-                    tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
-                    )
-
-            _message = litellm.Message(
-                tool_calls=tool_calls,
-                content=text_content or None,
-            )
-            model_response.choices[0].message = _message  # type: ignore
-            model_response._hidden_params["original_response"] = completion_response[
-                "content"
-            ]  # allow user to access raw anthropic tool calling response
-
-            model_response.choices[0].finish_reason = map_finish_reason(
-                completion_response["stop_reason"]
-            )
-
-        ## CALCULATING USAGE
-        prompt_tokens = completion_response["usage"]["input_tokens"]
-        completion_tokens = completion_response["usage"]["output_tokens"]
-        total_tokens = prompt_tokens + completion_tokens
-
-        model_response.created = int(time.time())
-        model_response.model = model
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-        )
-        setattr(model_response, "usage", usage)  # type: ignore
-        return model_response
-
     async def acompletion_stream_function(
         self,
         model: str,
@@ -392,6 +311,7 @@ class DatabricksChatCompletion(BaseLLM):
         logging_obj,
         stream,
         data: dict,
+        base_model: Optional[str],
         optional_params: dict,
         litellm_params=None,
         logger_fn=None,
@@ -420,7 +340,11 @@ class DatabricksChatCompletion(BaseLLM):
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))

-        return ModelResponse(**response_json)
+        response = ModelResponse(**response_json)
+
+        if base_model is not None:
+            response._hidden_params["model"] = base_model
+        return response

     def completion(
         self,
@@ -443,6 +367,7 @@ class DatabricksChatCompletion(BaseLLM):
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
+        base_model: Optional[str] = optional_params.pop("base_model", None)
         api_base, headers = self._validate_environment(
             api_base=api_base,
             api_key=api_key,
@@ -476,11 +401,11 @@ class DatabricksChatCompletion(BaseLLM):
                 "headers": headers,
             },
         )
-        if acompletion == True:
+        if acompletion is True:
            if client is not None and isinstance(client, HTTPHandler):
                client = None
            if (
-                stream is not None and stream == True
+                stream is not None and stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                print_verbose("makes async anthropic streaming POST request")
                data["stream"] = stream
@@ -521,6 +446,7 @@ class DatabricksChatCompletion(BaseLLM):
                     logger_fn=logger_fn,
                     headers=headers,
                     timeout=timeout,
+                    base_model=base_model,
                 )
         else:
             if client is None or not isinstance(client, HTTPHandler):
@@ -562,7 +488,12 @@ class DatabricksChatCompletion(BaseLLM):
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))

-        return ModelResponse(**response_json)
+        response = ModelResponse(**response_json)
+
+        if base_model is not None:
+            response._hidden_params["model"] = base_model
+
+        return response

     async def aembedding(
         self,
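Taken together, the databricks.py changes thread a `base_model` override through both the async and sync completion paths and copy it into `response._hidden_params["model"]`, which is the value `completion_cost` prefers when pricing a response. A rough usage sketch, assuming `base_model` reaches `optional_params` as an extra completion kwarg (the model names here are illustrative, not from the diff):

```python
import litellm

# hypothetical call; model/base_model values are illustrative assumptions
response = litellm.completion(
    model="databricks/databricks-dbrx-instruct",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    base_model="databricks-dbrx-instruct",  # pin the cost-map lookup to this name
)

# the handler stashes base_model in the hidden params, so downstream cost
# tracking resolves against the base model rather than the endpoint alias
print(response._hidden_params.get("model"))
```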
@@ -939,6 +939,8 @@ async def test_partner_models_httpx(model, sync_mode):
         response_format_tests(response=response)

         print(f"response: {response}")
+
+        assert response._hidden_params["response_cost"] > 0
     except litellm.RateLimitError as e:
         pass
     except Exception as e:
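The new assertion checks that cost tracking actually produces a positive figure for these partner models; callers can read the same number off the response. A minimal sketch (the model string and credentials are assumptions, not part of this diff):

```python
import litellm

# hypothetical Vertex AI partner-model call; requires valid Vertex credentials
resp = litellm.completion(
    model="vertex_ai/mistral-large@2407",
    messages=[{"role": "user", "content": "ping"}],
)

# litellm attaches the computed cost to the response's hidden params
cost = resp._hidden_params.get("response_cost")
assert cost is not None and cost > 0
```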
@@ -918,6 +918,42 @@ def test_vertex_ai_llama_predict_cost():
     assert predictive_cost == 0


+def test_vertex_ai_mistral_predict_cost():
+    from litellm.types.utils import Choices, Message, ModelResponse, Usage
+
+    response_object = ModelResponse(
+        id="26c0ef045020429d9c5c9b078c01e564",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="Hello! I'm Litellm Bot, your helpful assistant. While I can't provide real-time weather updates, I can help you find a reliable weather service or guide you on how to check the weather on your device. Would you like assistance with that?",
+                    role="assistant",
+                    tool_calls=None,
+                    function_call=None,
+                ),
+            )
+        ],
+        created=1722124652,
+        model="vertex_ai/mistral-large",
+        object="chat.completion",
+        system_fingerprint=None,
+        usage=Usage(prompt_tokens=32, completion_tokens=55, total_tokens=87),
+    )
+    model = "mistral-large@2407"
+    messages = [{"role": "user", "content": "Hey, hows it going???"}]
+    custom_llm_provider = "vertex_ai"
+    predictive_cost = completion_cost(
+        completion_response=response_object,
+        model=model,
+        messages=messages,
+        custom_llm_provider=custom_llm_provider,
+    )
+
+    assert predictive_cost > 0
+
+
 @pytest.mark.parametrize("model", ["openai/tts-1", "azure/tts-1"])
 def test_completion_cost_tts(model):
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
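The new test only asserts that the predicted cost is positive; under the hood the figure is the per-token prices from litellm's model cost map applied to the usage block. A back-of-the-envelope sketch with made-up prices (the real numbers live in the cost map, not in this diff):

```python
# illustrative per-token prices; actual values come from litellm's model cost map
input_cost_per_token = 3e-06
output_cost_per_token = 9e-06

# usage block from the test's ModelResponse
prompt_tokens, completion_tokens = 32, 55

cost = prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token
print(f"{cost:.8f}")  # positive, which is all the test asserts
```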