feat(databricks.py): support vertex mistral cost tracking

Krrish Dholakia 2024-07-27 20:22:35 -07:00
parent 6f9c29d39b
commit 6d5aedc48d
4 changed files with 56 additions and 87 deletions
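
In rough terms, the user-facing effect: Mistral models called through Vertex AI now get a response cost attached. A minimal sketch, assuming the "vertex_ai/mistral-large@2407" route and Vertex credentials are configured (the model string and message are illustrative, not from this commit); it mirrors the assertion added to test_partner_models_httpx below:

import litellm

# Hedged sketch: requires Vertex AI credentials; the model route is an assumption.
try:
    response = litellm.completion(
        model="vertex_ai/mistral-large@2407",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    assert response._hidden_params["response_cost"] > 0
except litellm.RateLimitError:
    pass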


@@ -509,7 +509,7 @@ def completion_cost(
         ):
             model = completion_response._hidden_params.get("model", model)
             custom_llm_provider = completion_response._hidden_params.get(
-                "custom_llm_provider", ""
+                "custom_llm_provider", custom_llm_provider or ""
             )
             region_name = completion_response._hidden_params.get(
                 "region_name", region_name
@@ -732,7 +732,7 @@ def response_cost_calculator(
             )
         return response_cost
     except litellm.NotFoundError as e:
-        print_verbose(
+        verbose_logger.debug(  # debug since it can be spammy in logs, for calls
             f"Model={model} for LLM Provider={custom_llm_provider} not found in completion cost map."
         )
         return None
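
A minimal sketch of what the fallback above changes, using a plain dict in place of _hidden_params: a provider passed explicitly into completion_cost is no longer discarded when the response did not record one.

# Plain-dict illustration of the new default expression; not litellm code.
hidden_params = {}                 # the response recorded no provider
custom_llm_provider = "vertex_ai"  # caller passed this into completion_cost
resolved = hidden_params.get("custom_llm_provider", custom_llm_provider or "")
assert resolved == "vertex_ai"     # the old default of "" dropped the caller's hint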


@@ -259,87 +259,6 @@ class DatabricksChatCompletion(BaseLLM):
             api_base = "{}/embeddings".format(api_base)
         return api_base, headers

-    def process_response(
-        self,
-        model: str,
-        response: Union[requests.Response, httpx.Response],
-        model_response: ModelResponse,
-        stream: bool,
-        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
-        optional_params: dict,
-        api_key: str,
-        data: Union[dict, str],
-        messages: List,
-        print_verbose,
-        encoding,
-    ) -> ModelResponse:
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        try:
-            completion_response = response.json()
-        except:
-            raise DatabricksError(
-                message=response.text, status_code=response.status_code
-            )
-        if "error" in completion_response:
-            raise DatabricksError(
-                message=str(completion_response["error"]),
-                status_code=response.status_code,
-            )
-        else:
-            text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
-                if content["type"] == "text":
-                    text_content += content["text"]
-                ## TOOL CALLING
-                elif content["type"] == "tool_use":
-                    tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
-                    )
-            _message = litellm.Message(
-                tool_calls=tool_calls,
-                content=text_content or None,
-            )
-            model_response.choices[0].message = _message  # type: ignore
-            model_response._hidden_params["original_response"] = completion_response[
-                "content"
-            ]  # allow user to access raw anthropic tool calling response
-            model_response.choices[0].finish_reason = map_finish_reason(
-                completion_response["stop_reason"]
-            )
-        ## CALCULATING USAGE
-        prompt_tokens = completion_response["usage"]["input_tokens"]
-        completion_tokens = completion_response["usage"]["output_tokens"]
-        total_tokens = prompt_tokens + completion_tokens
-        model_response.created = int(time.time())
-        model_response.model = model
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-        )
-        setattr(model_response, "usage", usage)  # type: ignore
-        return model_response

     async def acompletion_stream_function(
         self,
         model: str,
@@ -392,6 +311,7 @@ class DatabricksChatCompletion(BaseLLM):
         logging_obj,
         stream,
         data: dict,
+        base_model: Optional[str],
         optional_params: dict,
         litellm_params=None,
         logger_fn=None,
@@ -420,7 +340,11 @@ class DatabricksChatCompletion(BaseLLM):
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))

-        return ModelResponse(**response_json)
+        response = ModelResponse(**response_json)
+        if base_model is not None:
+            response._hidden_params["model"] = base_model
+        return response

     def completion(
         self,
@@ -443,6 +367,7 @@ class DatabricksChatCompletion(BaseLLM):
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
+        base_model: Optional[str] = optional_params.pop("base_model", None)
         api_base, headers = self._validate_environment(
             api_base=api_base,
             api_key=api_key,
@@ -476,11 +401,11 @@ class DatabricksChatCompletion(BaseLLM):
                 "headers": headers,
             },
         )
-        if acompletion == True:
+        if acompletion is True:
             if client is not None and isinstance(client, HTTPHandler):
                 client = None
             if (
-                stream is not None and stream == True
+                stream is not None and stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 print_verbose("makes async anthropic streaming POST request")
                 data["stream"] = stream
@@ -521,6 +446,7 @@ class DatabricksChatCompletion(BaseLLM):
                     logger_fn=logger_fn,
                     headers=headers,
                     timeout=timeout,
+                    base_model=base_model,
                 )
         else:
             if client is None or not isinstance(client, HTTPHandler):
@@ -562,7 +488,12 @@ class DatabricksChatCompletion(BaseLLM):
             except Exception as e:
                 raise DatabricksError(status_code=500, message=str(e))

-            return ModelResponse(**response_json)
+            response = ModelResponse(**response_json)
+            if base_model is not None:
+                response._hidden_params["model"] = base_model
+            return response

     async def aembedding(
         self,
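
A sketch of the pattern the databricks.py hunks introduce, assuming the Vertex AI partner-model path forwards a base_model through optional_params; build_response and response_json are names made up for illustration, not litellm's API:

from typing import Optional
from litellm import ModelResponse

def build_response(response_json: dict, base_model: Optional[str]) -> ModelResponse:
    # Parse the provider reply, then record the versioned base model so that
    # downstream cost lookup prices the base model rather than the request alias.
    response = ModelResponse(**response_json)
    if base_model is not None:
        response._hidden_params["model"] = base_model
    return response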


@@ -939,6 +939,8 @@ async def test_partner_models_httpx(model, sync_mode):
         response_format_tests(response=response)

         print(f"response: {response}")

         assert response._hidden_params["response_cost"] > 0
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:


@@ -918,6 +918,42 @@ def test_vertex_ai_llama_predict_cost():
     assert predictive_cost == 0


+def test_vertex_ai_mistral_predict_cost():
+    from litellm.types.utils import Choices, Message, ModelResponse, Usage
+
+    response_object = ModelResponse(
+        id="26c0ef045020429d9c5c9b078c01e564",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="Hello! I'm Litellm Bot, your helpful assistant. While I can't provide real-time weather updates, I can help you find a reliable weather service or guide you on how to check the weather on your device. Would you like assistance with that?",
+                    role="assistant",
+                    tool_calls=None,
+                    function_call=None,
+                ),
+            )
+        ],
+        created=1722124652,
+        model="vertex_ai/mistral-large",
+        object="chat.completion",
+        system_fingerprint=None,
+        usage=Usage(prompt_tokens=32, completion_tokens=55, total_tokens=87),
+    )
+    model = "mistral-large@2407"
+    messages = [{"role": "user", "content": "Hey, hows it going???"}]
+    custom_llm_provider = "vertex_ai"
+    predictive_cost = completion_cost(
+        completion_response=response_object,
+        model=model,
+        messages=messages,
+        custom_llm_provider=custom_llm_provider,
+    )
+    assert predictive_cost > 0
+
+
 @pytest.mark.parametrize("model", ["openai/tts-1", "azure/tts-1"])
 def test_completion_cost_tts(model):
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
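
For reference, a rough sketch of the arithmetic completion_cost performs for the usage in the new test; the per-token prices below are invented for illustration, the real ones come from litellm's model cost map.

input_cost_per_token = 3e-06   # hypothetical price for mistral-large@2407 prompt tokens
output_cost_per_token = 9e-06  # hypothetical price for completion tokens
prompt_tokens, completion_tokens = 32, 55  # from the Usage object in the test
cost = prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token
print(cost)  # ~0.000591 with these illustrative prices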