diff --git a/litellm/__init__.py b/litellm/__init__.py
index 43f91fe58..95172efa0 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -17,7 +17,11 @@ from litellm._logging import (
     _turn_on_json,
     log_level,
 )
-from litellm.constants import ROUTER_MAX_FALLBACKS
+from litellm.constants import (
+    DEFAULT_BATCH_SIZE,
+    DEFAULT_FLUSH_INTERVAL_SECONDS,
+    ROUTER_MAX_FALLBACKS,
+)
 from litellm.types.guardrails import GuardrailItem
 from litellm.proxy._types import (
     KeyManagementSystem,
diff --git a/litellm/constants.py b/litellm/constants.py
index 8d27cf564..331a0a630 100644
--- a/litellm/constants.py
+++ b/litellm/constants.py
@@ -1 +1,3 @@
 ROUTER_MAX_FALLBACKS = 5
+DEFAULT_BATCH_SIZE = 512
+DEFAULT_FLUSH_INTERVAL_SECONDS = 5
diff --git a/litellm/integrations/custom_batch_logger.py b/litellm/integrations/custom_batch_logger.py
index 7ef63d25c..292c836b3 100644
--- a/litellm/integrations/custom_batch_logger.py
+++ b/litellm/integrations/custom_batch_logger.py
@@ -8,20 +8,18 @@ import asyncio
 import time
 from typing import List, Literal, Optional
 
+import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
 
-DEFAULT_BATCH_SIZE = 512
-DEFAULT_FLUSH_INTERVAL_SECONDS = 5
-
 
 class CustomBatchLogger(CustomLogger):
 
     def __init__(
         self,
         flush_lock: Optional[asyncio.Lock] = None,
-        batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
-        flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
+        batch_size: Optional[int] = None,
+        flush_interval: Optional[int] = None,
         **kwargs,
     ) -> None:
         """
@@ -29,13 +27,12 @@ class CustomBatchLogger(CustomLogger):
         flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
         """
         self.log_queue: List = []
-        self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
-        self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
+        self.flush_interval = flush_interval or litellm.DEFAULT_FLUSH_INTERVAL_SECONDS
+        self.batch_size: int = batch_size or litellm.DEFAULT_BATCH_SIZE
         self.last_flush_time = time.time()
         self.flush_lock = flush_lock
 
         super().__init__(**kwargs)
-        pass
 
     async def periodic_flush(self):
         while True:
diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py
index 4abd2a2c3..f44ac5b87 100644
--- a/litellm/integrations/langsmith.py
+++ b/litellm/integrations/langsmith.py
@@ -68,8 +68,13 @@ class LangsmithLogger(CustomBatchLogger):
         if _batch_size:
             self.batch_size = int(_batch_size)
         self.log_queue: List[LangsmithQueueObject] = []
-        asyncio.create_task(self.periodic_flush())
+        loop = asyncio.get_event_loop_policy().get_event_loop()
+        if not loop.is_running():
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+        loop.create_task(self.periodic_flush())
         self.flush_lock = asyncio.Lock()
+
         super().__init__(**kwargs, flush_lock=self.flush_lock)
 
     def get_credentials_from_env(
@@ -122,7 +127,7 @@ class LangsmithLogger(CustomBatchLogger):
             "project_name", credentials["LANGSMITH_PROJECT"]
         )
         run_name = metadata.get("run_name", self.langsmith_default_run_name)
-        run_id = metadata.get("id", None)
+        run_id = metadata.get("id", metadata.get("run_id", None))
         parent_run_id = metadata.get("parent_run_id", None)
         trace_id = metadata.get("trace_id", None)
         session_id = metadata.get("session_id", None)
@@ -173,14 +178,28 @@ class LangsmithLogger(CustomBatchLogger):
         if dotted_order:
             data["dotted_order"] = dotted_order
 
+        run_id: Optional[str] = data.get("id")  # type: ignore
         if "id" not in data or data["id"] is None:
             """
             for /batch langsmith requires id, trace_id and dotted_order passed as params
             """
             run_id = str(uuid.uuid4())
-            data["id"] = str(run_id)
-            data["trace_id"] = str(run_id)
-            data["dotted_order"] = self.make_dot_order(run_id=run_id)
+
+            data["id"] = run_id
+
+        if (
+            "trace_id" not in data
+            or data["trace_id"] is None
+            and (run_id is not None and isinstance(run_id, str))
+        ):
+            data["trace_id"] = run_id
+
+        if (
+            "dotted_order" not in data
+            or data["dotted_order"] is None
+            and (run_id is not None and isinstance(run_id, str))
+        ):
+            data["dotted_order"] = self.make_dot_order(run_id=run_id)  # type: ignore
 
         verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)
diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py
index 483121c38..c1cc20a7d 100644
--- a/litellm/litellm_core_utils/streaming_handler.py
+++ b/litellm/litellm_core_utils/streaming_handler.py
@@ -437,29 +437,6 @@ class CustomStreamWrapper:
         except Exception:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")
 
-    def handle_cohere_chat_chunk(self, chunk):
-        chunk = chunk.decode("utf-8")
-        data_json = json.loads(chunk)
-        print_verbose(f"chunk: {chunk}")
-        try:
-            text = ""
-            is_finished = False
-            finish_reason = ""
-            if "text" in data_json:
-                text = data_json["text"]
-            elif "is_finished" in data_json and data_json["is_finished"] is True:
-                is_finished = data_json["is_finished"]
-                finish_reason = data_json["finish_reason"]
-            else:
-                return
-            return {
-                "text": text,
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
-        except Exception:
-            raise ValueError(f"Unable to parse response. Original response: {chunk}")
-
     def handle_azure_chunk(self, chunk):
         is_finished = False
         finish_reason = ""
@@ -949,7 +926,12 @@ class CustomStreamWrapper:
                     "function_call" in completion_obj
                     and completion_obj["function_call"] is not None
                 )
+                or (
+                    "provider_specific_fields" in response_obj
+                    and response_obj["provider_specific_fields"] is not None
+                )
             ):  # cannot set content of an OpenAI Object to be an empty string
+
                 self.safety_checker()
                 hold, model_response_str = self.check_special_tokens(
                     chunk=completion_obj["content"],
@@ -1058,6 +1040,7 @@ class CustomStreamWrapper:
                 and model_response.choices[0].delta.audio is not None
             ):
                 return model_response
+
             else:
                 if hasattr(model_response, "usage"):
                     self.chunks.append(model_response)
@@ -1066,6 +1049,7 @@ class CustomStreamWrapper:
     def chunk_creator(self, chunk):  # type: ignore  # noqa: PLR0915
         model_response = self.model_response_creator()
         response_obj: dict = {}
+
         try:
             # return this for all models
             completion_obj = {"content": ""}
@@ -1256,14 +1240,6 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
-            elif self.custom_llm_provider == "cohere_chat":
-                response_obj = self.handle_cohere_chat_chunk(chunk)
-                if response_obj is None:
-                    return
-                completion_obj["content"] = response_obj["text"]
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
-
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.received_finish_reason is not None:
diff --git a/litellm/llms/cohere/chat.py b/litellm/llms/cohere/chat.py
index e0a92b6c8..d7dfc3eaa 100644
--- a/litellm/llms/cohere/chat.py
+++ b/litellm/llms/cohere/chat.py
@@ -4,13 +4,20 @@ import time
 import traceback
 import types
 from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import httpx  # type: ignore
 import requests  # type: ignore
 
 import litellm
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.llms.cohere import ToolResultObject
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+)
 from litellm.utils import Choices, Message, ModelResponse, Usage
 
 from ..prompt_templates.factory import cohere_message_pt, cohere_messages_pt_v2
@@ -198,7 +205,107 @@ def construct_cohere_tool(tools=None):
     return cohere_tools
 
 
-def completion(
+async def make_call(
+    client: Optional[AsyncHTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+    timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
+) -> Tuple[Any, httpx.Headers]:
+    if client is None:
+        client = litellm.module_level_aclient
+
+    try:
+        response = await client.post(
+            api_base, headers=headers, data=data, stream=True, timeout=timeout
+        )
+    except httpx.HTTPStatusError as e:
+        error_headers = getattr(e, "headers", None)
+        error_response = getattr(e, "response", None)
+        if error_headers is None and error_response:
+            error_headers = getattr(error_response, "headers", None)
+        raise CohereError(
+            status_code=e.response.status_code,
+            message=await e.response.aread(),
+        )
+    except Exception as e:
+        for exception in litellm.LITELLM_EXCEPTION_TYPES:
+            if isinstance(e, exception):
+                raise e
+        raise CohereError(status_code=500, message=str(e))
+
+    completion_stream = ModelResponseIterator(
+        streaming_response=response.aiter_lines(),
+        sync_stream=False,
+        json_mode=json_mode,
+    )
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream, response.headers
+
+
+def make_sync_call(
+    client: Optional[HTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+    timeout: Optional[Union[float, httpx.Timeout]],
+) -> Tuple[Any, httpx.Headers]:
+    if client is None:
+        client = litellm.module_level_client  # re-use a module level client
+
+    try:
+        response = client.post(
+            api_base, headers=headers, data=data, stream=True, timeout=timeout
+        )
+    except httpx.HTTPStatusError as e:
+        raise CohereError(
+            status_code=e.response.status_code,
+            message=e.response.read(),
+        )
+    except Exception as e:
+        for exception in litellm.LITELLM_EXCEPTION_TYPES:
+            if isinstance(e, exception):
+                raise e
+        raise CohereError(status_code=500, message=str(e))
+
+    if response.status_code != 200:
+
+        raise CohereError(
+            status_code=response.status_code,
+            message=response.read(),
+        )
+
+    completion_stream = ModelResponseIterator(
+        streaming_response=response.iter_lines(), sync_stream=True
+    )
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response="first stream response received",
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream, response.headers
+
+
+def completion(  # noqa: PLR0915
     model: str,
     messages: list,
     api_base: str,
@@ -211,6 +318,8 @@ def completion(
     logging_obj,
     litellm_params=None,
     logger_fn=None,
+    client=None,
+    timeout=None,
 ):
     headers = validate_environment(api_key, headers=headers)
     completion_url = api_base
@@ -269,7 +378,23 @@ def completion(
         raise CohereError(message=response.text, status_code=response.status_code)
 
     if "stream" in optional_params and optional_params["stream"] is True:
-        return response.iter_lines()
+        completion_stream, cohere_headers = make_sync_call(
+            client=client,
+            api_base=api_base,
+            headers=headers,  # type: ignore
+            data=json.dumps(data),
+            model=model,
+            messages=messages,
+            logging_obj=logging_obj,
+            timeout=timeout,
+        )
+        return CustomStreamWrapper(
+            completion_stream=completion_stream,
+            model=model,
+            custom_llm_provider="cohere_chat",
+            logging_obj=logging_obj,
+            _response_headers=dict(cohere_headers),
+        )
     else:
         ## LOGGING
         logging_obj.post_call(
@@ -286,6 +411,10 @@ def completion(
         except Exception:
             raise CohereError(message=response.text, status_code=response.status_code)
 
+        ## ADD CITATIONS
+        if "citations" in completion_response:
+            setattr(model_response, "citations", completion_response["citations"])
+
         ## Tool calling response
         cohere_tools_response = completion_response.get("tool_calls", None)
         if cohere_tools_response is not None and cohere_tools_response != []:
@@ -325,3 +454,103 @@ def completion(
     )
     setattr(model_response, "usage", usage)
     return model_response
+
+
+class ModelResponseIterator:
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
+        self.streaming_response = streaming_response
+        self.response_iterator = self.streaming_response
+        self.content_blocks: List = []
+        self.tool_index = -1
+        self.json_mode = json_mode
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            index = int(chunk.get("index", 0))
+
+            if "text" in chunk:
+                text = chunk["text"]
+            elif "is_finished" in chunk and chunk["is_finished"] is True:
+                is_finished = chunk["is_finished"]
+                finish_reason = chunk["finish_reason"]
+
+            if "citations" in chunk:
+                provider_specific_fields = {"citations": chunk["citations"]}
+
+            returned_chunk = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+
+            return returned_chunk
+
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            chunk = self.response_iterator.__next__()
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
diff --git a/litellm/main.py b/litellm/main.py
index 5095ce518..6da7bb604 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1970,15 +1970,16 @@ def completion(  # type: ignore # noqa: PLR0915
                 logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
             )
 
-            if "stream" in optional_params and optional_params["stream"] is True:
-                # don't try to access stream object,
-                response = CustomStreamWrapper(
-                    model_response,
-                    model,
-                    custom_llm_provider="cohere_chat",
-                    logging_obj=logging,
-                )
-                return response
+            # if "stream" in optional_params and optional_params["stream"] is True:
+            #     # don't try to access stream object,
+            #     response = CustomStreamWrapper(
+            #         model_response,
+            #         model,
+            #         custom_llm_provider="cohere_chat",
+            #         logging_obj=logging,
+            #         _response_headers=headers,
+            #     )
+            #     return response
             response = model_response
         elif custom_llm_provider == "maritalk":
             maritalk_key = (
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 97ae3a54d..b9fae4d25 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -39,16 +39,4 @@ router_settings:
   redis_port: "os.environ/REDIS_PORT"
 
 litellm_settings:
-  cache: true
-  cache_params:
-    type: redis
-    host: "os.environ/REDIS_HOST"
-    port: "os.environ/REDIS_PORT"
-    namespace: "litellm.caching"
-    ttl: 600
-# key_generation_settings:
-#   team_key_generation:
-#     allowed_team_member_roles: ["admin"]
-#     required_params: ["tags"] # require team admins to set tags for cost-tracking when generating a team key
-#   personal_key_generation: # maps to 'Default Team' on UI
-#     allowed_user_roles: ["proxy_admin"]
\ No newline at end of file
+  success_callback: ["langsmith"]
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 156e9ed4e..c7cf90715 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.53.2"
+version = "1.53.1"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.commitizen]
-version = "1.53.2"
+version = "1.53.1"
 version_files = [
     "pyproject.toml:^version"
 ]
diff --git a/tests/llm_translation/test_cohere.py b/tests/llm_translation/test_cohere.py
new file mode 100644
index 000000000..b8cf8d4d1
--- /dev/null
+++ b/tests/llm_translation/test_cohere.py
@@ -0,0 +1,59 @@
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import json
+
+import pytest
+
+import litellm
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
+
+litellm.num_retries = 3
+
+
+@pytest.mark.parametrize("stream", [True, False])
+@pytest.mark.asyncio
+async def test_chat_completion_cohere_citations(stream):
+    try:
+        litellm.set_verbose = True
+        messages = [
+            {
+                "role": "user",
+                "content": "Which penguins are the tallest?",
+            },
+        ]
+        response = await litellm.acompletion(
+            model="cohere_chat/command-r",
+            messages=messages,
+            documents=[
+                {"title": "Tall penguins", "text": "Emperor penguins are the tallest."},
+                {
+                    "title": "Penguin habitats",
+                    "text": "Emperor penguins only live in Antarctica.",
+                },
+            ],
+            stream=stream,
+        )
+
+        if stream:
+            citations_chunk = False
+            async for chunk in response:
+                print("received chunk", chunk)
+                if "citations" in chunk:
+                    citations_chunk = True
+                    break
+            assert citations_chunk
+        else:
+            assert response.citations is not None
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
diff --git a/tests/local_testing/test_cohere_completion.py b/tests/local_testing/test_cohere_completion.py
deleted file mode 100644
index e90818fee..000000000
--- a/tests/local_testing/test_cohere_completion.py
+++ /dev/null
@@ -1,210 +0,0 @@
-import os
-import sys
-import traceback
-
-from dotenv import load_dotenv
-
-load_dotenv()
-import io
-import os
-
-sys.path.insert(
-    0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
-import json
-
-import pytest
-
-import litellm
-from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
-
-litellm.num_retries = 3
-
-
-# FYI - cohere_chat looks quite unstable, even when testing locally
-def test_chat_completion_cohere():
-    try:
-        litellm.set_verbose = True
-        messages = [
-            {
-                "role": "user",
-                "content": "Hey",
-            },
-        ]
-        response = completion(
-            model="cohere_chat/command-r",
-            messages=messages,
-            max_tokens=10,
-        )
-        print(response)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
-def test_chat_completion_cohere_tool_calling():
-    try:
-        litellm.set_verbose = True
-        messages = [
-            {
-                "role": "user",
-                "content": "What is the weather like in Boston?",
-            },
-        ]
-        response = completion(
-            model="cohere_chat/command-r",
-            messages=messages,
-            tools=[
-                {
-                    "type": "function",
-                    "function": {
-                        "name": "get_current_weather",
-                        "description": "Get the current weather in a given location",
-                        "parameters": {
-                            "type": "object",
-                            "properties": {
-                                "location": {
-                                    "type": "string",
-                                    "description": "The city and state, e.g. San Francisco, CA",
-                                },
-                                "unit": {
-                                    "type": "string",
-                                    "enum": ["celsius", "fahrenheit"],
-                                },
-                            },
-                            "required": ["location"],
-                        },
-                    },
-                }
-            ],
-        )
-        print(response)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-    # def get_current_weather(location, unit="fahrenheit"):
-    #     """Get the current weather in a given location"""
-    #     if "tokyo" in location.lower():
-    #         return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
-    #     elif "san francisco" in location.lower():
-    #         return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit})
-    #     elif "paris" in location.lower():
-    #         return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
-    #     else:
-    #         return json.dumps({"location": location, "temperature": "unknown"})
-
-    # def test_chat_completion_cohere_tool_with_result_calling():
-    #     # end to end cohere command-r with tool calling
-    #     # Step 1 - Send available tools
-    #     # Step 2 - Execute results
-    #     # Step 3 - Send results to command-r
-    #     try:
-    #         litellm.set_verbose = True
-    #         import json
-
-    #         # Step 1 - Send available tools
-    #         tools = [
-    #             {
-    #                 "type": "function",
-    #                 "function": {
-    #                     "name": "get_current_weather",
-    #                     "description": "Get the current weather in a given location",
-    #                     "parameters": {
-    #                         "type": "object",
-    #                         "properties": {
-    #                             "location": {
-    #                                 "type": "string",
-    #                                 "description": "The city and state, e.g. San Francisco, CA",
-    #                             },
-    #                             "unit": {
-    #                                 "type": "string",
-    #                                 "enum": ["celsius", "fahrenheit"],
-    #                             },
-    #                         },
-    #                         "required": ["location"],
-    #                     },
-    #                 },
-    #             }
-    #         ]
-
-    #         messages = [
-    #             {
-    #                 "role": "user",
-    #                 "content": "What is the weather like in Boston?",
-    #             },
-    #         ]
-    #         response = completion(
-    #             model="cohere_chat/command-r",
-    #             messages=messages,
-    #             tools=tools,
-    #         )
-    #         print("Response with tools to call", response)
-    #         print(response)
-
-    #         # step 2 - Execute results
-    #         tool_calls = response.tool_calls
-
-    #         available_functions = {
-    #             "get_current_weather": get_current_weather,
-    #         }  # only one function in this example, but you can have multiple
-
-    #         for tool_call in tool_calls:
-    #             function_name = tool_call.function.name
-    #             function_to_call = available_functions[function_name]
-    #             function_args = json.loads(tool_call.function.arguments)
-    #             function_response = function_to_call(
-    #                 location=function_args.get("location"),
-    #                 unit=function_args.get("unit"),
-    #             )
-    #             messages.append(
-    #                 {
-    #                     "tool_call_id": tool_call.id,
-    #                     "role": "tool",
-    #                     "name": function_name,
-    #                     "content": function_response,
-    #                 }
-    #             )  # extend conversation with function response
-
-    #         print("messages with tool call results", messages)
-
-    #         messages = [
-    #             {
-    #                 "role": "user",
-    #                 "content": "What is the weather like in Boston?",
-    #             },
-    #             {
-    #                 "tool_call_id": "tool_1",
-    #                 "role": "tool",
-    #                 "name": "get_current_weather",
-    #                 "content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"},
-    #             },
-    #         ]
-    #         respone = completion(
-    #             model="cohere_chat/command-r",
-    #             messages=messages,
-    #             tools=[
-    #                 {
-    #                     "type": "function",
-    #                     "function": {
-    #                         "name": "get_current_weather",
-    #                         "description": "Get the current weather in a given location",
-    #                         "parameters": {
-    #                             "type": "object",
-    #                             "properties": {
-    #                                 "location": {
-    #                                     "type": "string",
-    #                                     "description": "The city and state, e.g. San Francisco, CA",
-    #                                 },
-    #                                 "unit": {
-    #                                     "type": "string",
-    #                                     "enum": ["celsius", "fahrenheit"],
-    #                                 },
-    #                             },
-    #                             "required": ["location"],
-    #                         },
-    #                     },
-    #                 }
-    #             ],
-    #         )
-    #         print(respone)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
diff --git a/tests/local_testing/test_function_calling.py b/tests/local_testing/test_function_calling.py
index 6e1bd13a1..52a1fc697 100644
--- a/tests/local_testing/test_function_calling.py
+++ b/tests/local_testing/test_function_calling.py
@@ -46,11 +46,12 @@ def get_current_weather(location, unit="fahrenheit"):
     "model",
     [
         "gpt-3.5-turbo-1106",
-        # "mistral/mistral-large-latest",
+        "mistral/mistral-large-latest",
         "claude-3-haiku-20240307",
         "gemini/gemini-1.5-pro",
         "anthropic.claude-3-sonnet-20240229-v1:0",
-        # "groq/llama3-8b-8192",
+        "groq/llama3-8b-8192",
+        "cohere_chat/command-r",
     ],
 )
 @pytest.mark.flaky(retries=3, delay=1)
diff --git a/tests/local_testing/test_langsmith.py b/tests/local_testing/test_langsmith.py
index ab387e444..ff4c677d1 100644
--- a/tests/local_testing/test_langsmith.py
+++ b/tests/local_testing/test_langsmith.py
@@ -53,10 +53,17 @@ def test_async_langsmith_logging_with_metadata():
 @pytest.mark.asyncio
 async def test_async_langsmith_logging_with_streaming_and_metadata(sync_mode):
     try:
+        litellm.DEFAULT_BATCH_SIZE = 1
+        litellm.DEFAULT_FLUSH_INTERVAL_SECONDS = 1
         test_langsmith_logger = LangsmithLogger()
         litellm.success_callback = ["langsmith"]
         litellm.set_verbose = True
-        run_id = str(uuid.uuid4())
+        run_id = "497f6eca-6276-4993-bfeb-53cbbbba6f08"
+        run_name = "litellmRUN"
+        test_metadata = {
+            "run_name": run_name,  # langsmith run name
+            "run_id": run_id,  # langsmith run id
+        }
         messages = [{"role": "user", "content": "what llm are u"}]
 
         if sync_mode is True:
@@ -66,7 +73,7 @@ async def test_async_langsmith_logging_with_streaming_and_metadata(sync_mode):
                 max_tokens=10,
                 temperature=0.2,
                 stream=True,
-                metadata={"id": run_id},
+                metadata=test_metadata,
             )
             for cb in litellm.callbacks:
                 if isinstance(cb, LangsmithLogger):
@@ -82,7 +89,7 @@ async def test_async_langsmith_logging_with_streaming_and_metadata(sync_mode):
                 temperature=0.2,
                 mock_response="This is a mock request",
                 stream=True,
-                metadata={"id": run_id},
+                metadata=test_metadata,
            )
            for cb in litellm.callbacks:
                if isinstance(cb, LangsmithLogger):
@@ -100,11 +107,16 @@ async def test_async_langsmith_logging_with_streaming_and_metadata(sync_mode):
 
    input_fields_on_langsmith = logged_run_on_langsmith.get("inputs")
 
-    extra_fields_on_langsmith = logged_run_on_langsmith.get("extra").get(
+    extra_fields_on_langsmith = logged_run_on_langsmith.get("extra", {}).get(
        "invocation_params"
    )
 
-    assert logged_run_on_langsmith.get("run_type") == "llm"
+    assert (
+        logged_run_on_langsmith.get("run_type") == "llm"
+    ), f"run_type should be llm. Got: {logged_run_on_langsmith.get('run_type')}"
+    assert (
+        logged_run_on_langsmith.get("name") == run_name
+    ), f"run_type should be llm. Got: {logged_run_on_langsmith.get('run_type')}"
    print("\nLogged INPUT ON LANGSMITH", input_fields_on_langsmith)
 
    print("\nextra fields on langsmith", extra_fields_on_langsmith)