Merge pull request #1484 from BerriAI/litellm_access_key_metadata_in_callbacks
[Feat] Proxy - Access Key metadata in callbacks
commit 15ae9182db
3 changed files with 56 additions and 51 deletions
@@ -166,13 +166,24 @@ class LangFuseLogger:
         input,
         response_obj,
     ):
-        trace = self.Langfuse.trace(
-            name=metadata.get("generation_name", "litellm-completion"),
-            input=input,
-            output=output,
-            user_id=metadata.get("trace_user_id", user_id),
-            id=metadata.get("trace_id", None),
-        )
+        import langfuse
+
+        tags = []
+        supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
+
+        trace_params = {
+            "name": metadata.get("generation_name", "litellm-completion"),
+            "input": input,
+            "output": output,
+            "user_id": metadata.get("trace_user_id", user_id),
+            "id": metadata.get("trace_id", None),
+        }
+        if supports_tags:
+            for key, value in metadata.items():
+                tags.append(f"{key}:{value}")
+            trace_params.update({"tags": tags})
+
+        trace = self.Langfuse.trace(**trace_params)

         trace.generation(
             name=metadata.get("generation_name", "litellm-completion"),
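The new path builds the trace kwargs in a dict first so that tags can be attached conditionally: the tags argument is only passed when the installed Langfuse client is 2.6.3 or newer. A standalone sketch of the same gate, assuming the langfuse SDK and packaging are installed, and that Version in the hunk comes from packaging.version (imported elsewhere in the file):

# Standalone sketch of the version gate above. Assumes the `langfuse`
# SDK and `packaging` are installed; `Version` in the hunk presumably
# comes from `packaging.version`, imported elsewhere in the file.
import langfuse
from packaging.version import Version

# Illustrative stand-in for the real request metadata.
metadata = {"generation_name": "litellm-completion", "trace_user_id": "user-123"}

trace_params = {"name": metadata.get("generation_name", "litellm-completion")}

# Only pass `tags` when the installed client is new enough (>= 2.6.3)
# to support it; older clients never see the extra kwarg.
if Version(langfuse.version.__version__) >= Version("2.6.3"):
    trace_params["tags"] = [f"{key}:{value}" for key, value in metadata.items()]

# The integration then calls self.Langfuse.trace(**trace_params).

Since the loop iterates all of metadata, the proxy-injected keys added in the hunks below (user_api_key, user_api_key_metadata, headers, endpoint) also end up as key:value tags on the trace.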
@@ -1423,16 +1423,14 @@ async def completion(
         )
         if user_model:
             data["model"] = user_model
-        if "metadata" in data:
-            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-            data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-            data["metadata"]["headers"] = dict(request.headers)
-        else:
-            data["metadata"] = {
-                "user_api_key": user_api_key_dict.api_key,
-                "user_api_key_user_id": user_api_key_dict.user_id,
-            }
-            data["metadata"]["headers"] = dict(request.headers)
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["endpoint"] = str(request.url)
+
         # override with user settings, these are params passed via cli
         if user_temperature:
             data["temperature"] = user_temperature
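With user_api_key_metadata attached to every request's metadata, any callback can read it from kwargs["litellm_params"]["metadata"], which is exactly what the test at the bottom of this diff asserts. A minimal sketch of a custom handler consuming it; registration (e.g. via the proxy's callback settings) is omitted:

# Sketch: a proxy callback reading the key metadata attached above.
from litellm.integrations.custom_logger import CustomLogger


class KeyMetadataLogger(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        litellm_params = kwargs.get("litellm_params", {}) or {}
        metadata = litellm_params.get("metadata", {}) or {}
        # New in this PR: the metadata stored on the API key itself.
        print("key metadata:", metadata.get("user_api_key_metadata"))
        print("key user id:", metadata.get("user_api_key_user_id"))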
@@ -1584,15 +1582,13 @@ async def chat_completion(
         # if users are using user_api_key_auth, set `user` in `data`
         data["user"] = user_api_key_dict.user_id

-        if "metadata" in data:
-            verbose_proxy_logger.debug(f'received metadata: {data["metadata"]}')
-            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-            data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-            data["metadata"]["headers"] = dict(request.headers)
-        else:
-            data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
-            data["metadata"]["headers"] = dict(request.headers)
-            data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["endpoint"] = str(request.url)

         global user_temperature, user_request_timeout, user_max_tokens, user_api_base
         # override with user settings, these are params passed via cli
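After this block runs, every chat completion request carries a metadata dict shaped roughly like the sketch below; values are illustrative and the exact headers depend on the client:

# Illustrative shape of data["metadata"] after enrichment; values are examples.
enriched_metadata = {
    "user_api_key": "sk-...",
    "user_api_key_user_id": "user-123",
    "user_api_key_metadata": {"team": "ml-infra"},
    "headers": {"host": "localhost:4000", "content-type": "application/json"},
    "endpoint": "http://localhost:4000/chat/completions",
}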
@@ -1754,14 +1750,13 @@ async def embeddings(
         )
         if user_model:
             data["model"] = user_model
-        if "metadata" in data:
-            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-            data["metadata"]["headers"] = dict(request.headers)
-            data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-        else:
-            data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
-            data["metadata"]["headers"] = dict(request.headers)
-            data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["endpoint"] = str(request.url)

         router_model_names = (
             [m["model_name"] for m in llm_model_list]
@@ -1895,14 +1890,14 @@ async def image_generation(
         )
         if user_model:
             data["model"] = user_model
-        if "metadata" in data:
-            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-            data["metadata"]["headers"] = dict(request.headers)
-            data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-        else:
-            data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
-            data["metadata"]["headers"] = dict(request.headers)
-            data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["endpoint"] = str(request.url)

         router_model_names = (
             [m["model_name"] for m in llm_model_list]
@@ -2471,15 +2466,13 @@ async def async_queue_request(
         # if users are using user_api_key_auth, set `user` in `data`
         data["user"] = user_api_key_dict.user_id

-        if "metadata" in data:
-            verbose_proxy_logger.debug(f'received metadata: {data["metadata"]}')
-            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-            data["metadata"]["headers"] = dict(request.headers)
-            data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-        else:
-            data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
-            data["metadata"]["headers"] = dict(request.headers)
-            data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["endpoint"] = str(request.url)

         global user_temperature, user_request_timeout, user_max_tokens, user_api_base
         # override with user settings, these are params passed via cli
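The same enrichment block now exists in five endpoints (completion, chat_completion, embeddings, image_generation, async_queue_request). A hypothetical consolidation, not part of this PR, would be a shared helper:

# Hypothetical helper (not in this PR): one definition for the block
# repeated across the five endpoints above.
def _attach_key_metadata(data: dict, user_api_key_dict, request) -> None:
    if "metadata" not in data:
        data["metadata"] = {}
    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
    data["metadata"]["headers"] = dict(request.headers)
    data["metadata"]["endpoint"] = str(request.url)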
@@ -182,6 +182,7 @@ def test_chat_completion(client):
     print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
     assert metadata is not None
     assert "user_api_key" in metadata
+    assert "user_api_key_metadata" in metadata
     assert "headers" in metadata
     config_model_info = litellm_params.get("model_info")
     proxy_server_request_object = litellm_params.get("proxy_server_request")
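The user_api_key_metadata asserted here originates when the key is created. Assuming a running proxy whose /key/generate endpoint accepts a metadata field (as the proxy's key management does), the round trip looks roughly like:

# Sketch, assuming a proxy at localhost:4000 with master key "sk-1234".
import requests

resp = requests.post(
    "http://localhost:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={"metadata": {"team": "ml-infra", "env": "staging"}},
)
new_key = resp.json()["key"]
# Requests authenticated with new_key now expose
# {"team": "ml-infra", "env": "staging"} to every callback as
# metadata["user_api_key_metadata"].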