test: fix linting error

This commit is contained in:
Krrish Dholakia 2024-05-07 13:18:49 -07:00
parent 724660606a
commit b872da4e6f

View file

@ -21,7 +21,7 @@ import pytest
@pytest.fixture @pytest.fixture
def langfuse_client() -> "langfuse.Langfuse": def langfuse_client():
import langfuse import langfuse
langfuse_client = langfuse.Langfuse( langfuse_client = langfuse.Langfuse(
@ -29,9 +29,12 @@ def langfuse_client() -> "langfuse.Langfuse":
secret_key=os.environ["LANGFUSE_SECRET_KEY"], secret_key=os.environ["LANGFUSE_SECRET_KEY"],
) )
with patch("langfuse.Langfuse", MagicMock(return_value=langfuse_client)) as mock_langfuse_client: with patch(
"langfuse.Langfuse", MagicMock(return_value=langfuse_client)
) as mock_langfuse_client:
yield mock_langfuse_client() yield mock_langfuse_client()
def search_logs(log_file_path, num_good_logs=1): def search_logs(log_file_path, num_good_logs=1):
""" """
Searches the given log file for logs containing the "/api/public" string. Searches the given log file for logs containing the "/api/public" string.
@ -143,7 +146,7 @@ def test_langfuse_logging_async():
pytest.fail(f"An exception occurred - {e}") pytest.fail(f"An exception occurred - {e}")
async def make_async_calls(metadata = None, **completion_kwargs): async def make_async_calls(metadata=None, **completion_kwargs):
tasks = [] tasks = []
for _ in range(5): for _ in range(5):
tasks.append(create_async_task()) tasks.append(create_async_task())
@ -173,14 +176,14 @@ def create_async_task(**completion_kwargs):
By default a standard set of arguments are used for the litellm.acompletion function. By default a standard set of arguments are used for the litellm.acompletion function.
""" """
completion_args = { completion_args = {
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"messages": [{"role": "user", "content": "This is a test"}], "messages": [{"role": "user", "content": "This is a test"}],
"max_tokens": 5, "max_tokens": 5,
"temperature": 0.7, "temperature": 0.7,
"timeout": 5, "timeout": 5,
"user": "langfuse_latency_test_user", "user": "langfuse_latency_test_user",
"mock_response": "It's simple to use and easy to get started", "mock_response": "It's simple to use and easy to get started",
} }
completion_args.update(completion_kwargs) completion_args.update(completion_kwargs)
return asyncio.create_task(litellm.acompletion(**completion_args)) return asyncio.create_task(litellm.acompletion(**completion_args))
@ -195,7 +198,11 @@ async def test_langfuse_logging_without_request_response(stream, langfuse_client
litellm.set_verbose = True litellm.set_verbose = True
litellm.turn_off_message_logging = True litellm.turn_off_message_logging = True
litellm.success_callback = ["langfuse"] litellm.success_callback = ["langfuse"]
response = await create_async_task(model="gpt-3.5-turbo", stream=stream, metadata={"trace_id": _unique_trace_name}) response = await create_async_task(
model="gpt-3.5-turbo",
stream=stream,
metadata={"trace_id": _unique_trace_name},
)
print(response) print(response)
if stream: if stream:
async for chunk in response: async for chunk in response:
@ -232,49 +239,78 @@ async def test_langfuse_logging_metadata(langfuse_client):
Tags is just set for the trace Tags is just set for the trace
""" """
import uuid import uuid
litellm.set_verbose = True litellm.set_verbose = True
litellm.success_callback = ["langfuse"] litellm.success_callback = ["langfuse"]
trace_identifiers = {} trace_identifiers = {}
expected_filtered_metadata_keys = {"trace_name", "trace_id", "existing_trace_id", "trace_user_id", "session_id", "tags", "generation_name", "generation_id", "prompt"} expected_filtered_metadata_keys = {
trace_metadata = {"trace_actual_metadata_key": "trace_actual_metadata_value"} # Allows for setting the metadata on the trace "trace_name",
"trace_id",
"existing_trace_id",
"trace_user_id",
"session_id",
"tags",
"generation_name",
"generation_id",
"prompt",
}
trace_metadata = {
"trace_actual_metadata_key": "trace_actual_metadata_value"
} # Allows for setting the metadata on the trace
run_id = str(uuid.uuid4()) run_id = str(uuid.uuid4())
session_id = f"litellm-test-session-{run_id}" session_id = f"litellm-test-session-{run_id}"
trace_common_metadata = { trace_common_metadata = {
"session_id": session_id, "session_id": session_id,
"tags": ["litellm-test-tag1", "litellm-test-tag2"], "tags": ["litellm-test-tag1", "litellm-test-tag2"],
"update_trace_keys": ["output", "trace_metadata"], # Overwrite the following fields in the trace with the last generation's output and the trace_user_id "update_trace_keys": [
"output",
"trace_metadata",
], # Overwrite the following fields in the trace with the last generation's output and the trace_user_id
"trace_metadata": trace_metadata, "trace_metadata": trace_metadata,
"gen_metadata_key": "gen_metadata_value", # Metadata key that should not be filtered in the generation "gen_metadata_key": "gen_metadata_value", # Metadata key that should not be filtered in the generation
"trace_release": "litellm-test-release", "trace_release": "litellm-test-release",
"version": "litellm-test-version", "version": "litellm-test-version",
} }
for trace_num in range(1, 3): # Two traces for trace_num in range(1, 3): # Two traces
metadata = copy.deepcopy(trace_common_metadata) metadata = copy.deepcopy(trace_common_metadata)
trace_id = f"litellm-test-trace{trace_num}-{run_id}" trace_id = f"litellm-test-trace{trace_num}-{run_id}"
metadata["trace_id"] = trace_id metadata["trace_id"] = trace_id
metadata["trace_name"] = trace_id metadata["trace_name"] = trace_id
trace_identifiers[trace_id] = [] trace_identifiers[trace_id] = []
print(f"Trace: {trace_id}") print(f"Trace: {trace_id}")
for generation_num in range(1, trace_num + 1): # Each trace has a number of generations equal to its trace number for generation_num in range(
1, trace_num + 1
): # Each trace has a number of generations equal to its trace number
metadata["trace_user_id"] = f"litellm-test-user{generation_num}-{run_id}" metadata["trace_user_id"] = f"litellm-test-user{generation_num}-{run_id}"
generation_id = f"litellm-test-trace{trace_num}-generation-{generation_num}-{run_id}" generation_id = (
f"litellm-test-trace{trace_num}-generation-{generation_num}-{run_id}"
)
metadata["generation_id"] = generation_id metadata["generation_id"] = generation_id
metadata["generation_name"] = generation_id metadata["generation_name"] = generation_id
metadata["trace_metadata"]["generation_id"] = generation_id # Update to test if trace_metadata is overwritten by update trace keys metadata["trace_metadata"][
"generation_id"
] = generation_id # Update to test if trace_metadata is overwritten by update trace keys
trace_identifiers[trace_id].append(generation_id) trace_identifiers[trace_id].append(generation_id)
print(f"Generation: {generation_id}") print(f"Generation: {generation_id}")
response = await create_async_task(model="gpt-3.5-turbo", response = await create_async_task(
model="gpt-3.5-turbo",
mock_response=f"{session_id}:{trace_id}:{generation_id}", mock_response=f"{session_id}:{trace_id}:{generation_id}",
messages=[{"role": "user", "content": f"{session_id}:{trace_id}:{generation_id}"}], messages=[
{
"role": "user",
"content": f"{session_id}:{trace_id}:{generation_id}",
}
],
max_tokens=100, max_tokens=100,
temperature=0.2, temperature=0.2,
metadata=copy.deepcopy(metadata) # Every generation needs its own metadata, langfuse is not async/thread safe without it metadata=copy.deepcopy(
metadata
), # Every generation needs its own metadata, langfuse is not async/thread safe without it
) )
print(response) print(response)
metadata["existing_trace_id"] = trace_id metadata["existing_trace_id"] = trace_id
langfuse_client.flush() langfuse_client.flush()
await asyncio.sleep(2) await asyncio.sleep(2)
@ -284,20 +320,31 @@ async def test_langfuse_logging_metadata(langfuse_client):
assert trace.id == trace_id assert trace.id == trace_id
assert trace.session_id == session_id assert trace.session_id == session_id
assert trace.metadata != trace_metadata assert trace.metadata != trace_metadata
generations = list(reversed(langfuse_client.get_generations(trace_id=trace_id).data)) generations = list(
reversed(langfuse_client.get_generations(trace_id=trace_id).data)
)
assert len(generations) == len(generation_ids) assert len(generations) == len(generation_ids)
assert trace.input == generations[0].input # Should be set by the first generation assert (
assert trace.output == generations[-1].output # Should be overwritten by the last generation according to update_trace_keys trace.input == generations[0].input
assert trace.metadata != generations[-1].metadata # Should be overwritten by the last generation according to update_trace_keys ) # Should be set by the first generation
assert (
trace.output == generations[-1].output
) # Should be overwritten by the last generation according to update_trace_keys
assert (
trace.metadata != generations[-1].metadata
) # Should be overwritten by the last generation according to update_trace_keys
assert trace.metadata["generation_id"] == generations[-1].id assert trace.metadata["generation_id"] == generations[-1].id
assert set(trace.tags).issuperset(trace_common_metadata["tags"]) assert set(trace.tags).issuperset(trace_common_metadata["tags"])
print("trace_from_langfuse", trace) print("trace_from_langfuse", trace)
for generation_id, generation in zip(generation_ids, generations): for generation_id, generation in zip(generation_ids, generations):
assert generation.id == generation_id assert generation.id == generation_id
assert generation.trace_id == trace_id assert generation.trace_id == trace_id
assert set(generation.metadata.keys()).isdisjoint(expected_filtered_metadata_keys) assert set(generation.metadata.keys()).isdisjoint(
expected_filtered_metadata_keys
)
print("generation_from_langfuse", generation) print("generation_from_langfuse", generation)
@pytest.mark.skip(reason="beta test - checking langfuse output") @pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging(): def test_langfuse_logging():
try: try:
@ -657,7 +704,10 @@ def test_langfuse_existing_trace_id():
assert initial_langfuse_trace_dict == new_langfuse_trace_dict assert initial_langfuse_trace_dict == new_langfuse_trace_dict
@pytest.mark.skipif(condition=not os.environ.get("OPENAI_API_KEY", False), reason="Authentication missing for openai") @pytest.mark.skipif(
condition=not os.environ.get("OPENAI_API_KEY", False),
reason="Authentication missing for openai",
)
def test_langfuse_logging_tool_calling(): def test_langfuse_logging_tool_calling():
litellm.set_verbose = True litellm.set_verbose = True