forked from phoenix/litellm-mirror
fix(main.py): pass user_id + encoding_format for logging + to openai/azure
This commit is contained in:
parent: 35fa176c97
commit: 8b07a6c046
10 changed files with 82 additions and 26 deletions
BIN  dist/litellm-1.12.6.dev5-py3-none-any.whl  vendored  (binary file not shown)
BIN  dist/litellm-1.12.6.dev5.tar.gz  vendored  (binary file not shown)
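Usage sketch (not part of this diff): after this change, `user` and `encoding_format` supplied to an embedding call are forwarded to openai/azure and surfaced to logging callbacks. The model name mirrors the tests below; `encoding_format="float"` is an illustrative OpenAI-accepted value, and an `OPENAI_API_KEY` is assumed to be set.

import litellm

# `user` and `encoding_format` are the two optional embedding params this commit wires through.
response = litellm.embedding(
    model="text-embedding-ada-002",
    input=["hello world"],
    user="John",                # end-user id: logged and sent to the provider
    encoding_format="float",    # illustrative value; also logged and sent to the provider
)
print(response)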
@@ -415,7 +415,6 @@ class OpenAIChatCompletion(BaseLLM):
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
    raise OpenAIError(status_code=422, message="max retries must be an int")

## LOGGING
logging_obj.pre_call(
    input=input,
@@ -31,7 +31,8 @@ from litellm.utils import (
    mock_completion_streaming_obj,
    convert_to_model_response_object,
    token_counter,
-   Usage
+   Usage,
+   get_optional_params_embeddings
)
from .llms import (
    anthropic,
@@ -1828,18 +1829,16 @@ def embedding(
tpm = kwargs.pop("tpm", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
+ encoding_format = kwargs.get("encoding_format", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.pop("aembedding", None)
- openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
+ openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"]
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
optional_params = {}
for param in non_default_params:
    optional_params[param] = kwargs[param]

model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)


+ optional_params = get_optional_params_embeddings(user=user, encoding_format=encoding_format, custom_llm_provider=custom_llm_provider, **non_default_params)
try:
    response = None
    logging = litellm_logging_obj
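Illustration (a sketch, not code from the commit): with the whitelists above, caller kwargs outside `default_params` are treated as provider-specific and land in `non_default_params`, while `user` and `encoding_format` are passed explicitly to the new `get_optional_params_embeddings`. The `input_type` kwarg below is only an illustrative provider-specific parameter.

# Abbreviated stand-ins for the whitelists in the hunk above.
openai_params = ["user", "request_timeout", "api_key", "encoding_format"]
litellm_params = ["metadata", "aembedding", "caching", "model_info"]
default_params = openai_params + litellm_params

kwargs = {"user": "John", "encoding_format": "float", "input_type": "query", "metadata": {"team": "a"}}
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
print(non_default_params)  # {'input_type': 'query'} -> passed straight through to the provider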
@@ -1,3 +1,11 @@
import sys, os, traceback

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
) # Adds the parent directory to the system path

from litellm.integrations.custom_logger import CustomLogger
import litellm
import inspect
@@ -916,7 +916,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
except:
    data = json.loads(body_str)

- data["user"] = user_api_key_dict.user_id
+ data["user"] = data.get("user", user_api_key_dict.user_id)
data["model"] = (
    general_settings.get("completion_model", None) # server default
    or user_model # model name passed via cli args
@@ -1066,7 +1066,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
    "body": copy.copy(data) # use copy instead of deepcopy
}

- data["user"] = user_api_key_dict.user_id
+ data["user"] = data.get("user", user_api_key_dict.user_id)
data["model"] = (
    general_settings.get("embedding_model", None) # server default
    or user_model # model name passed via cli args
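In plain terms (a sketch, not part of the diff): both proxy endpoints now prefer a `user` sent in the request body and fall back to the API key's `user_id` only when the body omits it. The `_Key` class below is a stand-in for `UserAPIKeyAuth`.

class _Key:                       # stand-in for UserAPIKeyAuth in this sketch
    user_id = "key-owner-123"

user_api_key_dict = _Key()

data = {"model": "text-embedding-ada-002", "input": ["hi"], "user": "caller-42"}
data["user"] = data.get("user", user_api_key_dict.user_id)
assert data["user"] == "caller-42"          # body value wins

data = {"model": "text-embedding-ada-002", "input": ["hi"]}
data["user"] = data.get("user", user_api_key_dict.user_id)
assert data["user"] == "key-owner-123"      # falls back to the key's user_id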
@@ -1081,7 +1081,6 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
print(f"received data: {data['input']}")
if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int): # check if array of tokens passed in
    # check if non-openai/azure model called - e.g. for langchain integration
    if llm_model_list is not None and data["model"] in router_model_names:
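For context, the nested `isinstance` chain above just detects the OpenAI-style pre-tokenized input shape (a list of token-id lists). A small sketch, with explicit length guards added for safety:

def is_token_array(value) -> bool:
    # True when the input is a list of lists of token ids rather than a list of strings.
    return (
        isinstance(value, list)
        and len(value) > 0
        and isinstance(value[0], list)
        and len(value[0]) > 0
        and isinstance(value[0][0], int)
    )

print(is_token_array([[15496, 995]]))      # True  - pre-tokenized input
print(is_token_array(["hello world"]))     # False - plain string input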
@@ -1099,7 +1098,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen

### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")

print(f'final data: {data}')
## ROUTE TO CORRECT ENDPOINT ##
if llm_router is not None and data["model"] in router_model_names: # model in router model list
    response = await llm_router.aembedding(**data)
@@ -197,6 +197,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)

except:
    print(f"Assertion Error: {traceback.format_exc()}")
    self.errors.append(traceback.format_exc())
@@ -507,7 +508,7 @@ async def test_async_embedding_openai():
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
- assert len(customHandler_failure.states) == 3 # pre, post, success
+ assert len(customHandler_failure.states) == 3 # pre, post, failure
except Exception as e:
    pytest.fail(f"An exception occurred: {str(e)}")
@@ -26,9 +26,12 @@ class MyCustomHandler(CustomLogger):

self.stream_collected_response = None # type: ignore
self.sync_stream_collected_response = None # type: ignore
+ self.user = None # type: ignore
+ self.data_sent_to_api: dict = {}

def log_pre_api_call(self, model, messages, kwargs):
    print(f"Pre-API Call")
+   self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})

def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
    print(f"Post-API Call")
@@ -49,6 +52,7 @@ class MyCustomHandler(CustomLogger):

async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    print(f"On Async success")
+   print(f"received kwargs user: {kwargs['user']}")
    self.async_success = True
    if kwargs.get("model") == "text-embedding-ada-002":
        self.async_success_embedding = True
@@ -57,6 +61,7 @@ class MyCustomHandler(CustomLogger):
if kwargs.get("stream") == True:
    self.stream_collected_response = response_obj
self.async_completion_kwargs = kwargs
+ self.user = kwargs.get("user", None)

async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
    print(f"On Async Failure")
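Pulling the handler changes together (a hedged, simplified sketch rather than the full test class): a callback like this confirms that `user` shows up both in the logged kwargs and in the payload litellm hands to the provider SDK.

import litellm
from litellm.integrations.custom_logger import CustomLogger

class UserCapturingLogger(CustomLogger):
    # Stripped-down stand-in for MyCustomHandler above.
    def __init__(self):
        self.user = None
        self.data_sent_to_api: dict = {}

    def log_pre_api_call(self, model, messages, kwargs):
        # complete_input_dict is the request dict litellm surfaces before the provider call
        self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.user = kwargs.get("user", None)

handler = UserCapturingLogger()
litellm.callbacks = [handler]  # register the callback, as the test below does

The new test added further down drives exactly this path via `litellm.aembedding(...)` with `user="John"`.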
@@ -95,14 +100,6 @@ def test_async_chat_openai_stream():
print(complete_streaming_response)
asyncio.run(call_gpt())
complete_streaming_response = complete_streaming_response.strip("'")
- print(f"complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content']}")
- print(f"type of complete_streaming_response_in_callback: {type(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
- print(f"hidden char complete_streaming_response_in_callback: {repr(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
- print(f"encoding complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'].encode('utf-8')}")
- print(f"complete_streaming_response: {complete_streaming_response}")
- print(f"type(complete_streaming_response): {type(complete_streaming_response)}")
- print(f"hidden char complete_streaming_response): {repr(complete_streaming_response)}")
- print(f"encoding complete_streaming_response): {repr(complete_streaming_response).encode('utf-8')}")
response1 = tmp_function.complete_streaming_response_in_callback["choices"][0]["message"]["content"]
response2 = complete_streaming_response
assert [ord(c) for c in response1] == [ord(c) for c in response2]
@@ -110,7 +107,7 @@ def test_async_chat_openai_stream():
except Exception as e:
    print(e)
    pytest.fail(f"An error occurred - {str(e)}")
- test_async_chat_openai_stream()
+ # test_async_chat_openai_stream()

def test_completion_azure_stream_moderation_failure():
    try:
@@ -290,9 +287,29 @@ async def test_async_custom_handler_embedding():
assert len(str(customHandler_embedding.async_embedding_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
except Exception as e:
    pytest.fail(f"An exception occurred - {str(e)}")
- asyncio.run(test_async_custom_handler_embedding())
- from litellm import Cache
+ # asyncio.run(test_async_custom_handler_embedding())

+ @pytest.mark.asyncio
+ async def test_async_custom_handler_embedding_optional_param():
+     """
+     Tests if the openai optional params for embedding - user + encoding_format,
+     are logged
+     """
+     customHandler_optional_params = MyCustomHandler()
+     litellm.callbacks = [customHandler_optional_params]
+     response = await litellm.aembedding(
+         model="text-embedding-ada-002",
+         input = ["hello world"],
+         user = "John"
+     )
+     await asyncio.sleep(1) # success callback is async
+     assert customHandler_optional_params.user == "John"
+     assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"]

+ # asyncio.run(test_async_custom_handler_embedding_optional_param())

def test_redis_cache_completion_stream():
+   from litellm import Cache
    # Important Test - This tests if we can add to streaming cache, when custom callbacks are set
    import random
    try:
@@ -325,4 +342,4 @@ def test_redis_cache_completion_stream():
    print(e)
    litellm.success_callback = []
    raise e
- test_redis_cache_completion_stream()
+ # test_redis_cache_completion_stream()
@@ -2112,6 +2112,39 @@ def get_litellm_params(
return litellm_params


+ def get_optional_params_embeddings(
+     # 2 optional params
+     user=None,
+     encoding_format=None,
+     custom_llm_provider="",
+     **kwargs
+ ):
+     # retrieve all parameters passed to the function
+     passed_params = locals()
+     custom_llm_provider = passed_params.pop("custom_llm_provider", None)
+     special_params = passed_params.pop("kwargs")
+     for k, v in special_params.items():
+         passed_params[k] = v
+
+     default_params = {
+         "user": None,
+         "encoding_format": None
+     }
+
+     non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])}
+
+     ## raise exception if non-default value passed for non-openai/azure embedding calls
+     if custom_llm_provider != "openai" and custom_llm_provider != "azure":
+         if len(non_default_params.keys()) > 0:
+             if litellm.drop_params is True:
+                 for k in non_default_params.keys():
+                     passed_params.pop(k, None)
+                 return passed_params
+             raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.")
+
+     final_params = {**non_default_params, **kwargs}
+     return final_params

def get_optional_params( # use the openai defaults
    # 12 optional params
    functions=[],
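To make the new helper's behavior concrete, a hedged usage sketch (assumes a litellm build containing this commit; "huggingface" is just an example of a non-openai/azure provider):

import litellm
from litellm.utils import get_optional_params_embeddings

# openai/azure: user + encoding_format pass through untouched
params = get_optional_params_embeddings(user="John", encoding_format="float", custom_llm_provider="openai")
print(params)  # {'user': 'John', 'encoding_format': 'float'}

# other providers with drop_params enabled: the unsupported values are removed
litellm.drop_params = True
params = get_optional_params_embeddings(user="John", custom_llm_provider="huggingface")
print("user" in params)  # False - unsupported param dropped from the payload

With `litellm.drop_params` left at its default of `False`, the same non-openai/azure call raises `UnsupportedParamsError` instead of silently dropping the values.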