mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(triton/completion/transformation.py): remove bad_words / stop words from triton call (#10163)
* fix(triton/completion/transformation.py): remove bad_words / stop words from triton call
  parameter 'bad_words' has invalid type. It should be either 'int', 'bool', or 'string'.
* fix(proxy_track_cost_callback.py): add debug logging for track cost callback error
This commit is contained in:
parent
72cf30c081
commit
d73048ac46
4 changed files with 49 additions and 4 deletions
|
@ -201,8 +201,6 @@ class TritonGenerateConfig(TritonConfig):
|
||||||
"max_tokens": int(
|
"max_tokens": int(
|
||||||
optional_params.get("max_tokens", DEFAULT_MAX_TOKENS_FOR_TRITON)
|
optional_params.get("max_tokens", DEFAULT_MAX_TOKENS_FOR_TRITON)
|
||||||
),
|
),
|
||||||
"bad_words": [""],
|
|
||||||
"stop_words": [""],
|
|
||||||
},
|
},
|
||||||
"stream": bool(stream),
|
"stream": bool(stream),
|
||||||
}
|
}
|
||||||
|
|
|
@ -199,9 +199,13 @@ class _ProxyDBLogger(CustomLogger):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
|
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
|
||||||
model = kwargs.get("model", "")
|
model = kwargs.get("model", "")
|
||||||
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||||
|
litellm_metadata = kwargs.get("litellm_params", {}).get(
|
||||||
|
"litellm_metadata", {}
|
||||||
|
)
|
||||||
|
old_metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||||
call_type = kwargs.get("call_type", "")
|
call_type = kwargs.get("call_type", "")
|
||||||
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n call_type: {call_type}\n"
|
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n chosen_metadata: {metadata}\n litellm_metadata: {litellm_metadata}\n old_metadata: {old_metadata}\n call_type: {call_type}\n"
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
proxy_logging_obj.failed_tracking_alert(
|
proxy_logging_obj.failed_tracking_alert(
|
||||||
error_message=error_msg,
|
error_message=error_msg,
|
||||||
|
|
22
tests/litellm/litellm_core_utils/test_core_helpers.py
Normal file
22
tests/litellm/litellm_core_utils/test_core_helpers.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_litellm_metadata_from_kwargs():
    """Check the metadata lookup when `litellm_metadata` is present but empty.

    With an empty `litellm_metadata` dict and a populated `metadata` dict
    under `litellm_params`, the helper is expected to return the contents of
    `metadata`.
    """
    expected_metadata = {"user_api_key": "1234567890"}
    litellm_params = {
        "litellm_metadata": {},
        "metadata": {"user_api_key": "1234567890"},
    }
    request_kwargs = {"litellm_params": litellm_params}

    resolved_metadata = get_litellm_metadata_from_kwargs(request_kwargs)

    assert resolved_metadata == expected_metadata
|
|
@ -20,6 +20,7 @@ from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_split_embedding_by_shape_passes():
|
def test_split_embedding_by_shape_passes():
|
||||||
try:
|
try:
|
||||||
data = [
|
data = [
|
||||||
|
@ -230,3 +231,23 @@ async def test_triton_embeddings():
|
||||||
assert response.data[0]["embedding"] == [0.1, 0.2]
|
assert response.data[0]["embedding"] == [0.1, 0.2]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_triton_generate_raw_request():
    """Check that the raw Triton generate request omits `bad_words` / `stop_words`.

    Builds the raw request for a `triton/` model pointed at a `/generate`
    endpoint and asserts that neither key appears anywhere in the
    JSON-serialized request body. Any exception (including a failed
    assertion) is reported through `pytest.fail`.
    """
    from litellm.types.utils import CallTypes
    from litellm.utils import return_raw_request

    try:
        request_kwargs = {
            "model": "triton/llama-3-8b-instruct",
            "messages": [{"role": "user", "content": "who are u?"}],
            "api_base": "http://localhost:8000/generate",
        }
        raw_request = return_raw_request(
            endpoint=CallTypes.completion, kwargs=request_kwargs
        )
        print("raw_request", raw_request)
        assert raw_request is not None

        # Serialize once; both checks look for the key names anywhere in the body.
        serialized_body = json.dumps(raw_request["raw_request_body"])
        assert "bad_words" not in serialized_body
        assert "stop_words" not in serialized_body
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue