feat(databricks.py): support vertex mistral cost tracking

Krrish Dholakia 2024-07-27 20:22:35 -07:00
parent 6f9c29d39b
commit 6d5aedc48d
4 changed files with 56 additions and 87 deletions

databricks.py

@@ -259,87 +259,6 @@ class DatabricksChatCompletion(BaseLLM):
             api_base = "{}/embeddings".format(api_base)
         return api_base, headers
 
-    def process_response(
-        self,
-        model: str,
-        response: Union[requests.Response, httpx.Response],
-        model_response: ModelResponse,
-        stream: bool,
-        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
-        optional_params: dict,
-        api_key: str,
-        data: Union[dict, str],
-        messages: List,
-        print_verbose,
-        encoding,
-    ) -> ModelResponse:
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        try:
-            completion_response = response.json()
-        except:
-            raise DatabricksError(
-                message=response.text, status_code=response.status_code
-            )
-        if "error" in completion_response:
-            raise DatabricksError(
-                message=str(completion_response["error"]),
-                status_code=response.status_code,
-            )
-        else:
-            text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
-                if content["type"] == "text":
-                    text_content += content["text"]
-                ## TOOL CALLING
-                elif content["type"] == "tool_use":
-                    tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
-                    )
-            _message = litellm.Message(
-                tool_calls=tool_calls,
-                content=text_content or None,
-            )
-            model_response.choices[0].message = _message  # type: ignore
-            model_response._hidden_params["original_response"] = completion_response[
-                "content"
-            ]  # allow user to access raw anthropic tool calling response
-
-            model_response.choices[0].finish_reason = map_finish_reason(
-                completion_response["stop_reason"]
-            )
-
-        ## CALCULATING USAGE
-        prompt_tokens = completion_response["usage"]["input_tokens"]
-        completion_tokens = completion_response["usage"]["output_tokens"]
-        total_tokens = prompt_tokens + completion_tokens
-
-        model_response.created = int(time.time())
-        model_response.model = model
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-        )
-        setattr(model_response, "usage", usage)  # type: ignore
-
-        return model_response
-
     async def acompletion_stream_function(
         self,
         model: str,
@@ -392,6 +311,7 @@ class DatabricksChatCompletion(BaseLLM):
         logging_obj,
         stream,
         data: dict,
+        base_model: Optional[str],
         optional_params: dict,
         litellm_params=None,
         logger_fn=None,
@@ -420,7 +340,11 @@ class DatabricksChatCompletion(BaseLLM):
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))
 
-        return ModelResponse(**response_json)
+        response = ModelResponse(**response_json)
+
+        if base_model is not None:
+            response._hidden_params["model"] = base_model
+        return response
 
     def completion(
         self,
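The async path above now records the caller-supplied `base_model` in the response's `_hidden_params["model"]`. The practical effect, in isolation (a minimal sketch; the model strings are hypothetical, and the assumption is that litellm's cost tracking consults `_hidden_params["model"]` before the provider-reported model name):

```python
from litellm import ModelResponse

# Provider reports an endpoint-specific name with no pricing entry
# of its own (hypothetical example string).
response = ModelResponse(model="vertex_ai/mistral-large@2407")

base_model = "mistral-large-2407"  # hypothetical pricing key
if base_model is not None:
    # Same pattern as the hunk above: cost tracking can now resolve
    # pricing against the base model instead of the serving alias.
    response._hidden_params["model"] = base_model
```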
@@ -443,6 +367,7 @@ class DatabricksChatCompletion(BaseLLM):
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
+        base_model: Optional[str] = optional_params.pop("base_model", None)
         api_base, headers = self._validate_environment(
             api_base=api_base,
             api_key=api_key,
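Note that `base_model` is `pop`ped rather than read: it is consumed by the handler and must not be forwarded to the provider as a request parameter. The same pattern in isolation (parameter values are illustrative):

```python
# optional_params as it might arrive from the caller; values illustrative.
optional_params = {"temperature": 0.2, "base_model": "mistral-large-2407"}

base_model = optional_params.pop("base_model", None)  # "mistral-large-2407"

# Only genuine API parameters remain to be serialized into the request body.
assert "base_model" not in optional_params
```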
@@ -476,11 +401,11 @@ class DatabricksChatCompletion(BaseLLM):
                 "headers": headers,
             },
         )
-        if acompletion == True:
+        if acompletion is True:
             if client is not None and isinstance(client, HTTPHandler):
                 client = None
             if (
-                stream is not None and stream == True
+                stream is not None and stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 print_verbose("makes async anthropic streaming POST request")
                 data["stream"] = stream
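The `== True` → `is True` change in this hunk is a behavioral tightening, not just style: `==` matches any truthy-equal value (for example `1`), while `is` matches only the `True` singleton. A quick illustration:

```python
acompletion = 1  # a flag that arrived as an int rather than a bool

print(acompletion == True)  # True  -- 1 == True under numeric equality
print(acompletion is True)  # False -- 1 is not the bool singleton
```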
@@ -521,6 +446,7 @@ class DatabricksChatCompletion(BaseLLM):
                 logger_fn=logger_fn,
                 headers=headers,
                 timeout=timeout,
+                base_model=base_model,
             )
         else:
             if client is None or not isinstance(client, HTTPHandler):
@@ -562,7 +488,12 @@ class DatabricksChatCompletion(BaseLLM):
                 except Exception as e:
                     raise DatabricksError(status_code=500, message=str(e))
 
-                return ModelResponse(**response_json)
+                response = ModelResponse(**response_json)
+
+                if base_model is not None:
+                    response._hidden_params["model"] = base_model
+
+                return response
 
     async def aembedding(
         self,
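Taken together, the change lets callers on the Databricks-compatible route (which, per the commit title, Vertex Mistral goes through) price requests against a known model entry. A minimal usage sketch, assuming `base_model` is forwarded to the handler as an extra completion kwarg; the model strings are illustrative:

```python
import litellm

response = litellm.completion(
    model="vertex_ai/mistral-large@2407",  # hypothetical serving name
    messages=[{"role": "user", "content": "Hello"}],
    base_model="mistral-large-2407",  # consumed by the handler above
)

# completion_cost() resolves pricing from the response; per this commit,
# the hidden base model is what the lookup sees for this route.
cost = litellm.completion_cost(completion_response=response)
print(f"request cost: ${cost:.6f}")
```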