Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
feat(databricks.py): support vertex mistral cost tracking
This commit is contained in:
parent 6f9c29d39b
commit 6d5aedc48d

4 changed files with 56 additions and 87 deletions
@@ -509,7 +509,7 @@ def completion_cost(
         ):
             model = completion_response._hidden_params.get("model", model)
             custom_llm_provider = completion_response._hidden_params.get(
-                "custom_llm_provider", ""
+                "custom_llm_provider", custom_llm_provider or ""
             )
             region_name = completion_response._hidden_params.get(
                 "region_name", region_name

@@ -732,7 +732,7 @@ def response_cost_calculator(
             )
             return response_cost
     except litellm.NotFoundError as e:
-        print_verbose(
+        verbose_logger.debug(  # debug since it can be spammy in logs, for calls
             f"Model={model} for LLM Provider={custom_llm_provider} not found in completion cost map."
         )
         return None

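For orientation, a minimal sketch of what the provider fallback change above enables (hedged: the response object and model name below are placeholders, but the call pattern mirrors the test added at the bottom of this commit). A caller-supplied custom_llm_provider is no longer overwritten by an empty hidden-params value, so provider-prefixed cost-map entries such as vertex_ai/mistral-large resolve correctly:

    import litellm
    from litellm.types.utils import ModelResponse, Usage

    # Placeholder response whose hidden params carry no provider; before this
    # change the empty default would have wiped out the explicit argument below.
    response = ModelResponse(
        model="vertex_ai/mistral-large",
        usage=Usage(prompt_tokens=32, completion_tokens=55, total_tokens=87),
    )

    cost = litellm.completion_cost(
        completion_response=response,
        model="mistral-large@2407",
        custom_llm_provider="vertex_ai",  # preserved by the new fallback
    )
    assert cost > 0
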
@@ -259,87 +259,6 @@ class DatabricksChatCompletion(BaseLLM):
             api_base = "{}/embeddings".format(api_base)
         return api_base, headers
 
-    def process_response(
-        self,
-        model: str,
-        response: Union[requests.Response, httpx.Response],
-        model_response: ModelResponse,
-        stream: bool,
-        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
-        optional_params: dict,
-        api_key: str,
-        data: Union[dict, str],
-        messages: List,
-        print_verbose,
-        encoding,
-    ) -> ModelResponse:
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        try:
-            completion_response = response.json()
-        except:
-            raise DatabricksError(
-                message=response.text, status_code=response.status_code
-            )
-        if "error" in completion_response:
-            raise DatabricksError(
-                message=str(completion_response["error"]),
-                status_code=response.status_code,
-            )
-        else:
-            text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
-                if content["type"] == "text":
-                    text_content += content["text"]
-                ## TOOL CALLING
-                elif content["type"] == "tool_use":
-                    tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
-                    )
-
-            _message = litellm.Message(
-                tool_calls=tool_calls,
-                content=text_content or None,
-            )
-            model_response.choices[0].message = _message  # type: ignore
-            model_response._hidden_params["original_response"] = completion_response[
-                "content"
-            ]  # allow user to access raw anthropic tool calling response
-
-            model_response.choices[0].finish_reason = map_finish_reason(
-                completion_response["stop_reason"]
-            )
-
-        ## CALCULATING USAGE
-        prompt_tokens = completion_response["usage"]["input_tokens"]
-        completion_tokens = completion_response["usage"]["output_tokens"]
-        total_tokens = prompt_tokens + completion_tokens
-
-        model_response.created = int(time.time())
-        model_response.model = model
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-        )
-        setattr(model_response, "usage", usage)  # type: ignore
-        return model_response
-
     async def acompletion_stream_function(
         self,
         model: str,

@@ -392,6 +311,7 @@ class DatabricksChatCompletion(BaseLLM):
         logging_obj,
         stream,
         data: dict,
+        base_model: Optional[str],
         optional_params: dict,
         litellm_params=None,
         logger_fn=None,

@@ -420,7 +340,11 @@ class DatabricksChatCompletion(BaseLLM):
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))
 
-        return ModelResponse(**response_json)
+        response = ModelResponse(**response_json)
+
+        if base_model is not None:
+            response._hidden_params["model"] = base_model
+        return response
 
     def completion(
         self,

@@ -443,6 +367,7 @@ class DatabricksChatCompletion(BaseLLM):
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
+        base_model: Optional[str] = optional_params.pop("base_model", None)
         api_base, headers = self._validate_environment(
             api_base=api_base,
             api_key=api_key,

@@ -476,11 +401,11 @@ class DatabricksChatCompletion(BaseLLM):
                 "headers": headers,
             },
         )
-        if acompletion == True:
+        if acompletion is True:
             if client is not None and isinstance(client, HTTPHandler):
                 client = None
             if (
-                stream is not None and stream == True
+                stream is not None and stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 print_verbose("makes async anthropic streaming POST request")
                 data["stream"] = stream

@@ -521,6 +446,7 @@ class DatabricksChatCompletion(BaseLLM):
                 logger_fn=logger_fn,
                 headers=headers,
                 timeout=timeout,
+                base_model=base_model,
             )
         else:
             if client is None or not isinstance(client, HTTPHandler):

@@ -562,7 +488,12 @@ class DatabricksChatCompletion(BaseLLM):
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))
 
-        return ModelResponse(**response_json)
+        response = ModelResponse(**response_json)
+
+        if base_model is not None:
+            response._hidden_params["model"] = base_model
+
+        return response
 
     async def aembedding(
         self,

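Taken together with the @@ -509,7 hunk at the top, the intent of this base_model plumbing is: the caller passes base_model through optional_params, the Databricks handler pops it and stamps it onto the response's hidden params, and completion_cost then prefers that name when looking up pricing. A rough standalone sketch of the hand-off, using placeholder dicts rather than the real classes:

    # Placeholder dicts standing in for optional_params and response._hidden_params.
    optional_params = {"base_model": "mistral-large@2407"}

    base_model = optional_params.pop("base_model", None)   # what completion() now does
    hidden_params: dict = {}
    if base_model is not None:
        hidden_params["model"] = base_model                 # what the handler now records

    # ...and what completion_cost reads back (see the @@ -509,7 hunk):
    model = hidden_params.get("model", "databricks-served-model-name")
    assert model == "mistral-large@2407"
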
@@ -939,6 +939,8 @@ async def test_partner_models_httpx(model, sync_mode):
             response_format_tests(response=response)
 
         print(f"response: {response}")
+
+        assert response._hidden_params["response_cost"] > 0
     except litellm.RateLimitError as e:
         pass
     except Exception as e:

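The new assertion above leans on litellm attaching the computed cost to the response's hidden params. A hedged end-to-end sketch of what it is checking (the partner-model id and the credentials setup are assumptions, not shown in this diff):

    import litellm

    # Assumes Vertex AI credentials are configured in the environment;
    # the model id below is illustrative.
    response = litellm.completion(
        model="vertex_ai/mistral-large@2407",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response._hidden_params["response_cost"])  # expected to be > 0 after this change
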
@@ -918,6 +918,42 @@ def test_vertex_ai_llama_predict_cost():
     assert predictive_cost == 0
 
 
+def test_vertex_ai_mistral_predict_cost():
+    from litellm.types.utils import Choices, Message, ModelResponse, Usage
+
+    response_object = ModelResponse(
+        id="26c0ef045020429d9c5c9b078c01e564",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="Hello! I'm Litellm Bot, your helpful assistant. While I can't provide real-time weather updates, I can help you find a reliable weather service or guide you on how to check the weather on your device. Would you like assistance with that?",
+                    role="assistant",
+                    tool_calls=None,
+                    function_call=None,
+                ),
+            )
+        ],
+        created=1722124652,
+        model="vertex_ai/mistral-large",
+        object="chat.completion",
+        system_fingerprint=None,
+        usage=Usage(prompt_tokens=32, completion_tokens=55, total_tokens=87),
+    )
+    model = "mistral-large@2407"
+    messages = [{"role": "user", "content": "Hey, hows it going???"}]
+    custom_llm_provider = "vertex_ai"
+    predictive_cost = completion_cost(
+        completion_response=response_object,
+        model=model,
+        messages=messages,
+        custom_llm_provider=custom_llm_provider,
+    )
+
+    assert predictive_cost > 0
+
+
 @pytest.mark.parametrize("model", ["openai/tts-1", "azure/tts-1"])
 def test_completion_cost_tts(model):
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"