Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
feat(databricks.py): support vertex mistral cost tracking
This commit is contained in:
parent 6f9c29d39b, commit 6d5aedc48d

4 changed files with 56 additions and 87 deletions
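The databricks.py hunks below thread a new base_model parameter through the chat completion handler so that a call served from a custom endpoint can be priced as a known base model for cost tracking. A minimal usage sketch, assuming base_model is forwarded to the handler via optional params as the diff's pop suggests; the endpoint and model names here are hypothetical placeholders:

    import litellm

    # Hypothetical endpoint/model names, for illustration only.
    response = litellm.completion(
        model="databricks/my-custom-endpoint",  # deployed endpoint name
        messages=[{"role": "user", "content": "Hello"}],
        base_model="mistral-large",  # popped from optional_params by the handler
    )

    # The handler stashes base_model in the response's hidden params so that
    # downstream cost tracking prices the call against the base model.
    print(response._hidden_params.get("model"))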
@@ -259,87 +259,6 @@ class DatabricksChatCompletion(BaseLLM):
         api_base = "{}/embeddings".format(api_base)
         return api_base, headers
 
-    def process_response(
-        self,
-        model: str,
-        response: Union[requests.Response, httpx.Response],
-        model_response: ModelResponse,
-        stream: bool,
-        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
-        optional_params: dict,
-        api_key: str,
-        data: Union[dict, str],
-        messages: List,
-        print_verbose,
-        encoding,
-    ) -> ModelResponse:
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        try:
-            completion_response = response.json()
-        except:
-            raise DatabricksError(
-                message=response.text, status_code=response.status_code
-            )
-        if "error" in completion_response:
-            raise DatabricksError(
-                message=str(completion_response["error"]),
-                status_code=response.status_code,
-            )
-        else:
-            text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
-                if content["type"] == "text":
-                    text_content += content["text"]
-                ## TOOL CALLING
-                elif content["type"] == "tool_use":
-                    tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
-                    )
-
-            _message = litellm.Message(
-                tool_calls=tool_calls,
-                content=text_content or None,
-            )
-            model_response.choices[0].message = _message  # type: ignore
-            model_response._hidden_params["original_response"] = completion_response[
-                "content"
-            ]  # allow user to access raw anthropic tool calling response
-
-            model_response.choices[0].finish_reason = map_finish_reason(
-                completion_response["stop_reason"]
-            )
-
-        ## CALCULATING USAGE
-        prompt_tokens = completion_response["usage"]["input_tokens"]
-        completion_tokens = completion_response["usage"]["output_tokens"]
-        total_tokens = prompt_tokens + completion_tokens
-
-        model_response.created = int(time.time())
-        model_response.model = model
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-        )
-        setattr(model_response, "usage", usage)  # type: ignore
-        return model_response
-
     async def acompletion_stream_function(
         self,
         model: str,
@@ -392,6 +311,7 @@ class DatabricksChatCompletion(BaseLLM):
         logging_obj,
         stream,
         data: dict,
+        base_model: Optional[str],
         optional_params: dict,
         litellm_params=None,
         logger_fn=None,
@@ -420,7 +340,11 @@ class DatabricksChatCompletion(BaseLLM):
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))
 
-        return ModelResponse(**response_json)
+        response = ModelResponse(**response_json)
+
+        if base_model is not None:
+            response._hidden_params["model"] = base_model
+        return response
 
     def completion(
         self,
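The hunk above replaces the bare return of ModelResponse(**response_json) with a named response so the caller-supplied base_model can be written into the response's hidden params before it is returned; downstream cost calculation can then look up pricing for the base model instead of the opaque endpoint name. A self-contained sketch of the same pattern, with illustrative values:

    from litellm import ModelResponse

    response_json = {"model": "my-custom-endpoint"}  # illustrative payload
    base_model = "mistral-large"  # illustrative base model

    # Same pattern as the hunk: build the response, then point the hidden
    # model name at base_model so pricing resolves against a known model.
    response = ModelResponse(**response_json)
    if base_model is not None:
        response._hidden_params["model"] = base_model
    print(response._hidden_params["model"])  # -> mistral-large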
@@ -443,6 +367,7 @@ class DatabricksChatCompletion(BaseLLM):
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
+        base_model: Optional[str] = optional_params.pop("base_model", None)
         api_base, headers = self._validate_environment(
             api_base=api_base,
             api_key=api_key,
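The pop above strips base_model out of optional_params before the remaining params are serialized into the request body, so the endpoint never receives a field it does not recognize; the None default keeps the call safe when the caller omits it. A tiny sketch of the behavior (values are illustrative):

    # dict.pop removes the key and returns its value, or the default if absent.
    optional_params = {"temperature": 0.2, "base_model": "mistral-large"}
    base_model = optional_params.pop("base_model", None)  # -> "mistral-large"
    # optional_params is now {"temperature": 0.2}; a second pop returns None.
    print(base_model, optional_params)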
@@ -476,11 +401,11 @@ class DatabricksChatCompletion(BaseLLM):
                 "headers": headers,
             },
         )
-        if acompletion == True:
+        if acompletion is True:
             if client is not None and isinstance(client, HTTPHandler):
                 client = None
             if (
-                stream is not None and stream == True
+                stream is not None and stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 print_verbose("makes async anthropic streaming POST request")
                 data["stream"] = stream
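The two one-line edits in this hunk tighten equality checks into identity checks: x == True is satisfied by any value that compares equal to True (such as 1), while x is True matches only the True singleton. A quick illustration:

    flag = 1
    print(flag == True)  # True: 1 compares equal to True
    print(flag is True)  # False: 1 is not the True singleton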
@@ -521,6 +446,7 @@ class DatabricksChatCompletion(BaseLLM):
                 logger_fn=logger_fn,
                 headers=headers,
                 timeout=timeout,
+                base_model=base_model,
             )
         else:
             if client is None or not isinstance(client, HTTPHandler):
@@ -562,7 +488,12 @@ class DatabricksChatCompletion(BaseLLM):
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))
 
-        return ModelResponse(**response_json)
+        response = ModelResponse(**response_json)
+
+        if base_model is not None:
+            response._hidden_params["model"] = base_model
+
+        return response
 
     async def aembedding(
         self,