diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py
index db410d986..4bf60adc9 100644
--- a/litellm/integrations/langfuse.py
+++ b/litellm/integrations/langfuse.py
@@ -166,13 +166,24 @@ class LangFuseLogger:
         input,
         response_obj,
     ):
-        trace = self.Langfuse.trace(
-            name=metadata.get("generation_name", "litellm-completion"),
-            input=input,
-            output=output,
-            user_id=metadata.get("trace_user_id", user_id),
-            id=metadata.get("trace_id", None),
-        )
+        import langfuse
+
+        tags = []
+        supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
+
+        trace_params = {
+            "name": metadata.get("generation_name", "litellm-completion"),
+            "input": input,
+            "output": output,
+            "user_id": metadata.get("trace_user_id", user_id),
+            "id": metadata.get("trace_id", None),
+        }
+        if supports_tags:
+            for key, value in metadata.items():
+                tags.append(f"{key}:{value}")
+            trace_params.update({"tags": tags})
+
+        trace = self.Langfuse.trace(**trace_params)

         trace.generation(
             name=metadata.get("generation_name", "litellm-completion"),
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 1e11e8615..e014b1b6b 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1423,16 +1423,14 @@ async def completion(
     )
     if user_model:
         data["model"] = user_model
-    if "metadata" in data:
-        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-        data["metadata"]["headers"] = dict(request.headers)
-    else:
-        data["metadata"] = {
-            "user_api_key": user_api_key_dict.api_key,
-            "user_api_key_user_id": user_api_key_dict.user_id,
-        }
-        data["metadata"]["headers"] = dict(request.headers)
+    if "metadata" not in data:
+        data["metadata"] = {}
+    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+    data["metadata"]["headers"] = dict(request.headers)
+    data["metadata"]["endpoint"] = str(request.url)
+
     # override with user settings, these are params passed via cli
     if user_temperature:
         data["temperature"] = user_temperature
@@ -1584,15 +1582,13 @@ async def chat_completion(

     # if users are using user_api_key_auth, set `user` in `data`
     data["user"] = user_api_key_dict.user_id
-    if "metadata" in data:
-        verbose_proxy_logger.debug(f'received metadata: {data["metadata"]}')
-        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-        data["metadata"]["headers"] = dict(request.headers)
-    else:
-        data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
-        data["metadata"]["headers"] = dict(request.headers)
-        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+    if "metadata" not in data:
+        data["metadata"] = {}
+    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+    data["metadata"]["headers"] = dict(request.headers)
+    data["metadata"]["endpoint"] = str(request.url)

     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     # override with user settings, these are params passed via cli
@@ -1754,14 +1750,13 @@ async def embeddings(
     )
     if user_model:
         data["model"] = user_model
-    if "metadata" in data:
-        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-        data["metadata"]["headers"] = dict(request.headers)
-        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-    else:
-        data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
-        data["metadata"]["headers"] = dict(request.headers)
-        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+    if "metadata" not in data:
+        data["metadata"] = {}
+    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+    data["metadata"]["headers"] = dict(request.headers)
+    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+    data["metadata"]["endpoint"] = str(request.url)

     router_model_names = (
         [m["model_name"] for m in llm_model_list]
@@ -1895,14 +1890,14 @@ async def image_generation(
     )
     if user_model:
         data["model"] = user_model
-    if "metadata" in data:
-        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-        data["metadata"]["headers"] = dict(request.headers)
-        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-    else:
-        data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
-        data["metadata"]["headers"] = dict(request.headers)
-        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+
+    if "metadata" not in data:
+        data["metadata"] = {}
+    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+    data["metadata"]["headers"] = dict(request.headers)
+    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+    data["metadata"]["endpoint"] = str(request.url)

     router_model_names = (
         [m["model_name"] for m in llm_model_list]
@@ -2471,15 +2466,13 @@ async def async_queue_request(

     # if users are using user_api_key_auth, set `user` in `data`
     data["user"] = user_api_key_dict.user_id
-    if "metadata" in data:
-        verbose_proxy_logger.debug(f'received metadata: {data["metadata"]}')
-        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-        data["metadata"]["headers"] = dict(request.headers)
-        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-    else:
-        data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
-        data["metadata"]["headers"] = dict(request.headers)
-        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+    if "metadata" not in data:
+        data["metadata"] = {}
+    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+    data["metadata"]["headers"] = dict(request.headers)
+    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+    data["metadata"]["endpoint"] = str(request.url)

     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     # override with user settings, these are params passed via cli
diff --git a/litellm/tests/test_proxy_custom_logger.py b/litellm/tests/test_proxy_custom_logger.py
index e47351a9b..34e427ef4 100644
--- a/litellm/tests/test_proxy_custom_logger.py
+++ b/litellm/tests/test_proxy_custom_logger.py
@@ -182,6 +182,7 @@ def test_chat_completion(client):
         print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
         assert metadata is not None
         assert "user_api_key" in metadata
+        assert "user_api_key_metadata" in metadata
         assert "headers" in metadata
         config_model_info = litellm_params.get("model_info")
         proxy_server_request_object = litellm_params.get("proxy_server_request")