Compare commits

..

6 commits

Author SHA1 Message Date
Krish Dholakia
4b9c66ea59 LiteLLM Minor Fixes & Improvements (11/29/2024) (#6965)
* fix(factory.py): ensure tool call converts image url

Fixes https://github.com/BerriAI/litellm/issues/6953

* fix(transformation.py): support mp4 + pdf url's for vertex ai

Fixes https://github.com/BerriAI/litellm/issues/6936

* fix(http_handler.py): mask gemini api key in error logs

Fixes https://github.com/BerriAI/litellm/issues/6963

* docs(prometheus.md): update prometheus FAQs

* feat(auth_checks.py): ensure specific model access > wildcard model access

if wildcard model is in access group, but specific model is not - deny access

* fix(auth_checks.py): handle auth checks for team based model access groups

handles scenario where model access group used for wildcard models

* fix(internal_user_endpoints.py): support adding guardrails on `/user/update`

Fixes https://github.com/BerriAI/litellm/issues/6942

* fix(key_management_endpoints.py): fix prepare_metadata_fields helper

* fix: fix tests

* build(requirements.txt): bump openai dep version

fixes proxies argument

* test: fix tests

* fix(http_handler.py): fix error message masking

* fix(bedrock_guardrails.py): pass in prepped data

* test: fix test

* test: fix nvidia nim test

* fix(http_handler.py): return original response headers

* fix: revert maskedhttpstatuserror

* test: update tests

* test: cleanup test

* fix(key_management_endpoints.py): fix metadata field update logic

* fix(key_management_endpoints.py): maintain initial order of guardrails in key update

* fix(key_management_endpoints.py): handle prepare metadata

* fix: fix linting errors

* fix: fix linting errors

* fix: fix linting errors

* fix: fix key management errors

* fix(key_management_endpoints.py): update metadata

* test: update test

* refactor: add more debug statements

* test: skip flaky test

* test: fix test

* fix: fix test

* fix: fix update metadata logic

* fix: fix test

* ci(config.yml): change db url for e2e ui testing
2024-12-01 05:26:06 -08:00
Krrish Dholakia
afb892c6d0 fix: suppress linting error 2024-11-30 17:26:49 -08:00
Krrish Dholakia
b9585d2016 fix(langsmith.py): fix langsmith quickstart
Fixes https://github.com/BerriAI/litellm/issues/6861
2024-11-30 17:24:39 -08:00
Krrish Dholakia
147dfa61b0 fix(langsmith.py): support 'run_id' for langsmith
Fixes https://github.com/BerriAI/litellm/issues/6862
2024-11-30 16:45:23 -08:00
Krrish Dholakia
927f9fa4eb fix(cohere/chat.py): fix linting errors 2024-11-30 16:01:04 -08:00
Krrish Dholakia
2fbc71a62c feat(cohere/chat.py): return citations in model response
Closes https://github.com/BerriAI/litellm/issues/6814
2024-11-30 13:59:57 -08:00
13 changed files with 367 additions and 289 deletions

View file

@@ -17,7 +17,11 @@ from litellm._logging import (
     _turn_on_json,
     log_level,
 )
-from litellm.constants import ROUTER_MAX_FALLBACKS
+from litellm.constants import (
+    DEFAULT_BATCH_SIZE,
+    DEFAULT_FLUSH_INTERVAL_SECONDS,
+    ROUTER_MAX_FALLBACKS,
+)
 from litellm.types.guardrails import GuardrailItem
 from litellm.proxy._types import (
     KeyManagementSystem,

View file

@@ -1 +1,3 @@
 ROUTER_MAX_FALLBACKS = 5
+DEFAULT_BATCH_SIZE = 512
+DEFAULT_FLUSH_INTERVAL_SECONDS = 5

View file

@@ -8,20 +8,18 @@ import asyncio
 import time
 from typing import List, Literal, Optional

+import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger

-DEFAULT_BATCH_SIZE = 512
-DEFAULT_FLUSH_INTERVAL_SECONDS = 5
-

 class CustomBatchLogger(CustomLogger):
     def __init__(
         self,
         flush_lock: Optional[asyncio.Lock] = None,
-        batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
-        flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
+        batch_size: Optional[int] = None,
+        flush_interval: Optional[int] = None,
         **kwargs,
     ) -> None:
         """
@@ -29,13 +27,12 @@ class CustomBatchLogger(CustomLogger):
             flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
         """
         self.log_queue: List = []
-        self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
-        self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
+        self.flush_interval = flush_interval or litellm.DEFAULT_FLUSH_INTERVAL_SECONDS
+        self.batch_size: int = batch_size or litellm.DEFAULT_BATCH_SIZE
         self.last_flush_time = time.time()
         self.flush_lock = flush_lock
         super().__init__(**kwargs)
-        pass

     async def periodic_flush(self):
         while True:

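With the two batching constants promoted to litellm/constants.py and read back through the litellm module, batch behaviour can now be tuned globally instead of by editing the logger. A minimal sketch of the new fallback, assuming this branch is installed; MyBatchLogger is a hypothetical subclass used only for illustration:

    import litellm
    from litellm.integrations.custom_batch_logger import CustomBatchLogger

    class MyBatchLogger(CustomBatchLogger):  # hypothetical subclass, for illustration only
        pass

    litellm.DEFAULT_BATCH_SIZE = 1             # e.g. force tiny batches in a test
    litellm.DEFAULT_FLUSH_INTERVAL_SECONDS = 1

    logger = MyBatchLogger()                   # batch_size / flush_interval left unset
    print(logger.batch_size, logger.flush_interval)  # -> 1 1, resolved from the module-level constants

This is the same override the updated Langsmith test at the bottom of this compare relies on.
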
View file

@@ -68,8 +68,13 @@ class LangsmithLogger(CustomBatchLogger):
         if _batch_size:
             self.batch_size = int(_batch_size)
         self.log_queue: List[LangsmithQueueObject] = []
-        asyncio.create_task(self.periodic_flush())
+        loop = asyncio.get_event_loop_policy().get_event_loop()
+        if not loop.is_running():
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+        loop.create_task(self.periodic_flush())
         self.flush_lock = asyncio.Lock()
         super().__init__(**kwargs, flush_lock=self.flush_lock)

     def get_credentials_from_env(
@@ -122,7 +127,7 @@ class LangsmithLogger(CustomBatchLogger):
             "project_name", credentials["LANGSMITH_PROJECT"]
         )
         run_name = metadata.get("run_name", self.langsmith_default_run_name)
-        run_id = metadata.get("id", None)
+        run_id = metadata.get("id", metadata.get("run_id", None))
         parent_run_id = metadata.get("parent_run_id", None)
         trace_id = metadata.get("trace_id", None)
         session_id = metadata.get("session_id", None)
@@ -173,14 +178,28 @@ class LangsmithLogger(CustomBatchLogger):
         if dotted_order:
             data["dotted_order"] = dotted_order

+        run_id: Optional[str] = data.get("id")  # type: ignore
+
         if "id" not in data or data["id"] is None:
             """
             for /batch langsmith requires id, trace_id and dotted_order passed as params
             """
             run_id = str(uuid.uuid4())
-            data["id"] = str(run_id)
-            data["trace_id"] = str(run_id)
-            data["dotted_order"] = self.make_dot_order(run_id=run_id)
+
+            data["id"] = run_id
+
+        if (
+            "trace_id" not in data
+            or data["trace_id"] is None
+            and (run_id is not None and isinstance(run_id, str))
+        ):
+            data["trace_id"] = run_id
+
+        if (
+            "dotted_order" not in data
+            or data["dotted_order"] is None
+            and (run_id is not None and isinstance(run_id, str))
+        ):
+            data["dotted_order"] = self.make_dot_order(run_id=run_id)  # type: ignore

         verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)

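Because run_id now falls back to metadata["run_id"] (and an id is only minted when none was supplied), a caller can pin the Langsmith run id and run name per request. A short sketch of the calling side, mirroring the updated test at the end of this compare; the model name and mock response are illustrative:

    import litellm

    litellm.success_callback = ["langsmith"]

    litellm.completion(
        model="gpt-3.5-turbo",  # illustrative model
        messages=[{"role": "user", "content": "what llm are u"}],
        mock_response="This is a mock request",
        metadata={
            "run_name": "litellmRUN",                          # langsmith run name
            "run_id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",  # langsmith run id
        },
    )
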
View file

@@ -437,29 +437,6 @@ class CustomStreamWrapper:
         except Exception:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")

-    def handle_cohere_chat_chunk(self, chunk):
-        chunk = chunk.decode("utf-8")
-        data_json = json.loads(chunk)
-        print_verbose(f"chunk: {chunk}")
-        try:
-            text = ""
-            is_finished = False
-            finish_reason = ""
-            if "text" in data_json:
-                text = data_json["text"]
-            elif "is_finished" in data_json and data_json["is_finished"] is True:
-                is_finished = data_json["is_finished"]
-                finish_reason = data_json["finish_reason"]
-            else:
-                return
-            return {
-                "text": text,
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
-        except Exception:
-            raise ValueError(f"Unable to parse response. Original response: {chunk}")
-
     def handle_azure_chunk(self, chunk):
         is_finished = False
         finish_reason = ""
@@ -949,7 +926,12 @@
                     "function_call" in completion_obj
                     and completion_obj["function_call"] is not None
                 )
+                or (
+                    "provider_specific_fields" in response_obj
+                    and response_obj["provider_specific_fields"] is not None
+                )
             ):  # cannot set content of an OpenAI Object to be an empty string
                 self.safety_checker()
                 hold, model_response_str = self.check_special_tokens(
                     chunk=completion_obj["content"],
@@ -1058,6 +1040,7 @@
                 and model_response.choices[0].delta.audio is not None
             ):
                 return model_response
+
             else:
                 if hasattr(model_response, "usage"):
                     self.chunks.append(model_response)
@@ -1066,6 +1049,7 @@
     def chunk_creator(self, chunk):  # type: ignore  # noqa: PLR0915
         model_response = self.model_response_creator()
         response_obj: dict = {}
+
         try:
             # return this for all models
             completion_obj = {"content": ""}
@@ -1256,14 +1240,6 @@
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
-            elif self.custom_llm_provider == "cohere_chat":
-                response_obj = self.handle_cohere_chat_chunk(chunk)
-                if response_obj is None:
-                    return
-                completion_obj["content"] = response_obj["text"]
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.received_finish_reason is not None:

View file

@@ -4,13 +4,20 @@ import time
 import traceback
 import types
 from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, List, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.llms.cohere import ToolResultObject
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+)
 from litellm.utils import Choices, Message, ModelResponse, Usage

 from ..prompt_templates.factory import cohere_message_pt, cohere_messages_pt_v2
@@ -198,7 +205,107 @@ def construct_cohere_tool(tools=None):
     return cohere_tools


-def completion(
+async def make_call(
+    client: Optional[AsyncHTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+    timeout: Optional[Union[float, httpx.Timeout]],
+    json_mode: bool,
+) -> Tuple[Any, httpx.Headers]:
+    if client is None:
+        client = litellm.module_level_aclient
+
+    try:
+        response = await client.post(
+            api_base, headers=headers, data=data, stream=True, timeout=timeout
+        )
+    except httpx.HTTPStatusError as e:
+        error_headers = getattr(e, "headers", None)
+        error_response = getattr(e, "response", None)
+        if error_headers is None and error_response:
+            error_headers = getattr(error_response, "headers", None)
+        raise CohereError(
+            status_code=e.response.status_code,
+            message=await e.response.aread(),
+        )
+    except Exception as e:
+        for exception in litellm.LITELLM_EXCEPTION_TYPES:
+            if isinstance(e, exception):
+                raise e
+        raise CohereError(status_code=500, message=str(e))
+
+    completion_stream = ModelResponseIterator(
+        streaming_response=response.aiter_lines(),
+        sync_stream=False,
+        json_mode=json_mode,
+    )
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream, response.headers
+
+
+def make_sync_call(
+    client: Optional[HTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+    timeout: Optional[Union[float, httpx.Timeout]],
+) -> Tuple[Any, httpx.Headers]:
+    if client is None:
+        client = litellm.module_level_client  # re-use a module level client
+
+    try:
+        response = client.post(
+            api_base, headers=headers, data=data, stream=True, timeout=timeout
+        )
+    except httpx.HTTPStatusError as e:
+        raise CohereError(
+            status_code=e.response.status_code,
+            message=e.response.read(),
+        )
+    except Exception as e:
+        for exception in litellm.LITELLM_EXCEPTION_TYPES:
+            if isinstance(e, exception):
+                raise e
+        raise CohereError(status_code=500, message=str(e))
+
+    if response.status_code != 200:
+        raise CohereError(
+            status_code=response.status_code,
+            message=response.read(),
+        )
+
+    completion_stream = ModelResponseIterator(
+        streaming_response=response.iter_lines(), sync_stream=True
+    )
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response="first stream response received",
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream, response.headers
+
+
+def completion(  # noqa: PLR0915
     model: str,
     messages: list,
     api_base: str,
@@ -211,6 +318,8 @@ def completion(
     logging_obj,
     litellm_params=None,
     logger_fn=None,
+    client=None,
+    timeout=None,
 ):
     headers = validate_environment(api_key, headers=headers)
     completion_url = api_base
@@ -269,7 +378,23 @@ def completion(
         raise CohereError(message=response.text, status_code=response.status_code)

     if "stream" in optional_params and optional_params["stream"] is True:
-        return response.iter_lines()
+        completion_stream, cohere_headers = make_sync_call(
+            client=client,
+            api_base=api_base,
+            headers=headers,  # type: ignore
+            data=json.dumps(data),
+            model=model,
+            messages=messages,
+            logging_obj=logging_obj,
+            timeout=timeout,
+        )
+        return CustomStreamWrapper(
+            completion_stream=completion_stream,
+            model=model,
+            custom_llm_provider="cohere_chat",
+            logging_obj=logging_obj,
+            _response_headers=dict(cohere_headers),
+        )
     else:
         ## LOGGING
         logging_obj.post_call(
@@ -286,6 +411,10 @@ def completion(
         except Exception:
             raise CohereError(message=response.text, status_code=response.status_code)

+        ## ADD CITATIONS
+        if "citations" in completion_response:
+            setattr(model_response, "citations", completion_response["citations"])
+
         ## Tool calling response
         cohere_tools_response = completion_response.get("tool_calls", None)
         if cohere_tools_response is not None and cohere_tools_response != []:
@@ -325,3 +454,103 @@ def completion(
     )
     setattr(model_response, "usage", usage)
     return model_response
+
+
+class ModelResponseIterator:
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
+        self.streaming_response = streaming_response
+        self.response_iterator = self.streaming_response
+        self.content_blocks: List = []
+        self.tool_index = -1
+        self.json_mode = json_mode
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            index = int(chunk.get("index", 0))
+
+            if "text" in chunk:
+                text = chunk["text"]
+            elif "is_finished" in chunk and chunk["is_finished"] is True:
+                is_finished = chunk["is_finished"]
+                finish_reason = chunk["finish_reason"]
+
+            if "citations" in chunk:
+                provider_specific_fields = {"citations": chunk["citations"]}
+
+            returned_chunk = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+
+            return returned_chunk
+
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            chunk = self.response_iterator.__next__()
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

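The new ModelResponseIterator replaces the handle_cohere_chat_chunk path removed from the streaming handler above: each decoded event is mapped into a GenericStreamingChunk, with any citations carried in provider_specific_fields, which appears to be why the streaming handler's empty-content guard now also accepts chunks whose provider_specific_fields are set. A rough illustration of that mapping, using made-up event dicts shaped only around the keys chunk_parser reads; the import path is an assumption and requires this branch to be installed:

    from litellm.llms.cohere.chat import ModelResponseIterator  # assumed module path

    # Hypothetical Cohere stream events, trimmed to the keys chunk_parser looks at
    # ("text", "is_finished"/"finish_reason", "index", "citations").
    events = [
        {"text": "Emperor penguins are the tallest."},
        {"citations": [{"start": 0, "end": 33, "document_ids": ["doc_0"]}]},
        {"is_finished": True, "finish_reason": "COMPLETE"},
    ]

    iterator = ModelResponseIterator(streaming_response=iter(events), sync_stream=True)
    for event in events:
        parsed = iterator.chunk_parser(chunk=event)
        # A citations-only event parses to empty text plus
        # provider_specific_fields={"citations": [...]}.
        print(parsed)
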
View file

@@ -1970,15 +1970,16 @@ def completion(  # type: ignore  # noqa: PLR0915
                 logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
             )

-            if "stream" in optional_params and optional_params["stream"] is True:
-                # don't try to access stream object,
-                response = CustomStreamWrapper(
-                    model_response,
-                    model,
-                    custom_llm_provider="cohere_chat",
-                    logging_obj=logging,
-                )
-                return response
+            # if "stream" in optional_params and optional_params["stream"] is True:
+            #     # don't try to access stream object,
+            #     response = CustomStreamWrapper(
+            #         model_response,
+            #         model,
+            #         custom_llm_provider="cohere_chat",
+            #         logging_obj=logging,
+            #         _response_headers=headers,
+            #     )
+            #     return response
             response = model_response
         elif custom_llm_provider == "maritalk":
             maritalk_key = (

View file

@@ -39,16 +39,4 @@ router_settings:
   redis_port: "os.environ/REDIS_PORT"

 litellm_settings:
-  cache: true
-  cache_params:
-    type: redis
-    host: "os.environ/REDIS_HOST"
-    port: "os.environ/REDIS_PORT"
-    namespace: "litellm.caching"
-    ttl: 600
-# key_generation_settings:
-#   team_key_generation:
-#     allowed_team_member_roles: ["admin"]
-#     required_params: ["tags"] # require team admins to set tags for cost-tracking when generating a team key
-#   personal_key_generation: # maps to 'Default Team' on UI
-#     allowed_user_roles: ["proxy_admin"]
+  success_callback: ["langsmith"]

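For reference, the SDK-side equivalent of this proxy config change is to toggle the same callback in code, with Langsmith credentials read from the environment by get_credentials_from_env. A hedged sketch; apart from LANGSMITH_PROJECT (which the logger above references), the environment variable name and model are assumptions for illustration:

    import os
    import litellm

    os.environ["LANGSMITH_API_KEY"] = "ls-..."       # assumed variable name
    os.environ["LANGSMITH_PROJECT"] = "litellm-dev"  # referenced by the logger above

    litellm.success_callback = ["langsmith"]         # same effect as the proxy yaml above

    litellm.completion(
        model="gpt-3.5-turbo",  # illustrative
        messages=[{"role": "user", "content": "hello"}],
        mock_response="hello back",
    )
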
View file

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.53.2"
+version = "1.53.1"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.53.2"
+version = "1.53.1"
 version_files = [
     "pyproject.toml:^version"
 ]

View file

@@ -0,0 +1,59 @@
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import json
+
+import pytest
+
+import litellm
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
+
+litellm.num_retries = 3
+
+
+@pytest.mark.parametrize("stream", [True, False])
+@pytest.mark.asyncio
+async def test_chat_completion_cohere_citations(stream):
+    try:
+        litellm.set_verbose = True
+        messages = [
+            {
+                "role": "user",
+                "content": "Which penguins are the tallest?",
+            },
+        ]
+        response = await litellm.acompletion(
+            model="cohere_chat/command-r",
+            messages=messages,
+            documents=[
+                {"title": "Tall penguins", "text": "Emperor penguins are the tallest."},
+                {
+                    "title": "Penguin habitats",
+                    "text": "Emperor penguins only live in Antarctica.",
+                },
+            ],
+            stream=stream,
+        )
+
+        if stream:
+            citations_chunk = False
+            async for chunk in response:
+                print("received chunk", chunk)
+                if "citations" in chunk:
+                    citations_chunk = True
+                    break
+            assert citations_chunk
+        else:
+            assert response.citations is not None
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")

View file

@@ -1,210 +0,0 @@
-import os
-import sys
-import traceback
-
-from dotenv import load_dotenv
-
-load_dotenv()
-import io
-import os
-
-sys.path.insert(
-    0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
-import json
-
-import pytest
-
-import litellm
-from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
-
-litellm.num_retries = 3
-
-
-# FYI - cohere_chat looks quite unstable, even when testing locally
-def test_chat_completion_cohere():
-    try:
-        litellm.set_verbose = True
-        messages = [
-            {
-                "role": "user",
-                "content": "Hey",
-            },
-        ]
-        response = completion(
-            model="cohere_chat/command-r",
-            messages=messages,
-            max_tokens=10,
-        )
-        print(response)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
-def test_chat_completion_cohere_tool_calling():
-    try:
-        litellm.set_verbose = True
-        messages = [
-            {
-                "role": "user",
-                "content": "What is the weather like in Boston?",
-            },
-        ]
-        response = completion(
-            model="cohere_chat/command-r",
-            messages=messages,
-            tools=[
-                {
-                    "type": "function",
-                    "function": {
-                        "name": "get_current_weather",
-                        "description": "Get the current weather in a given location",
-                        "parameters": {
-                            "type": "object",
-                            "properties": {
-                                "location": {
-                                    "type": "string",
-                                    "description": "The city and state, e.g. San Francisco, CA",
-                                },
-                                "unit": {
-                                    "type": "string",
-                                    "enum": ["celsius", "fahrenheit"],
-                                },
-                            },
-                            "required": ["location"],
-                        },
-                    },
-                }
-            ],
-        )
-        print(response)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-    # def get_current_weather(location, unit="fahrenheit"):
-    #     """Get the current weather in a given location"""
-    #     if "tokyo" in location.lower():
-    #         return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
-    #     elif "san francisco" in location.lower():
-    #         return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit})
-    #     elif "paris" in location.lower():
-    #         return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
-    #     else:
-    #         return json.dumps({"location": location, "temperature": "unknown"})
-
-    # def test_chat_completion_cohere_tool_with_result_calling():
-    #     # end to end cohere command-r with tool calling
-    #     # Step 1 - Send available tools
-    #     # Step 2 - Execute results
-    #     # Step 3 - Send results to command-r
-    #     try:
-    #         litellm.set_verbose = True
-    #         import json
-
-    #         # Step 1 - Send available tools
-    #         tools = [
-    #             {
-    #                 "type": "function",
-    #                 "function": {
-    #                     "name": "get_current_weather",
-    #                     "description": "Get the current weather in a given location",
-    #                     "parameters": {
-    #                         "type": "object",
-    #                         "properties": {
-    #                             "location": {
-    #                                 "type": "string",
-    #                                 "description": "The city and state, e.g. San Francisco, CA",
-    #                             },
-    #                             "unit": {
-    #                                 "type": "string",
-    #                                 "enum": ["celsius", "fahrenheit"],
-    #                             },
-    #                         },
-    #                         "required": ["location"],
-    #                     },
-    #                 },
-    #             }
-    #         ]
-    #         messages = [
-    #             {
-    #                 "role": "user",
-    #                 "content": "What is the weather like in Boston?",
-    #             },
-    #         ]
-    #         response = completion(
-    #             model="cohere_chat/command-r",
-    #             messages=messages,
-    #             tools=tools,
-    #         )
-    #         print("Response with tools to call", response)
-    #         print(response)
-
-    #         # step 2 - Execute results
-    #         tool_calls = response.tool_calls
-
-    #         available_functions = {
-    #             "get_current_weather": get_current_weather,
-    #         }  # only one function in this example, but you can have multiple
-
-    #         for tool_call in tool_calls:
-    #             function_name = tool_call.function.name
-    #             function_to_call = available_functions[function_name]
-    #             function_args = json.loads(tool_call.function.arguments)
-    #             function_response = function_to_call(
-    #                 location=function_args.get("location"),
-    #                 unit=function_args.get("unit"),
-    #             )
-    #             messages.append(
-    #                 {
-    #                     "tool_call_id": tool_call.id,
-    #                     "role": "tool",
-    #                     "name": function_name,
-    #                     "content": function_response,
-    #                 }
-    #             )  # extend conversation with function response
-    #         print("messages with tool call results", messages)
-
-    #         messages = [
-    #             {
-    #                 "role": "user",
-    #                 "content": "What is the weather like in Boston?",
-    #             },
-    #             {
-    #                 "tool_call_id": "tool_1",
-    #                 "role": "tool",
-    #                 "name": "get_current_weather",
-    #                 "content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"},
-    #             },
-    #         ]
-    #         respone = completion(
-    #             model="cohere_chat/command-r",
-    #             messages=messages,
-    #             tools=[
-    #                 {
-    #                     "type": "function",
-    #                     "function": {
-    #                         "name": "get_current_weather",
-    #                         "description": "Get the current weather in a given location",
-    #                         "parameters": {
-    #                             "type": "object",
-    #                             "properties": {
-    #                                 "location": {
-    #                                     "type": "string",
-    #                                     "description": "The city and state, e.g. San Francisco, CA",
-    #                                 },
-    #                                 "unit": {
-    #                                     "type": "string",
-    #                                     "enum": ["celsius", "fahrenheit"],
-    #                                 },
-    #                             },
-    #                             "required": ["location"],
-    #                         },
-    #                     },
-    #                 }
-    #             ],
-    #         )
-    #         print(respone)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")

View file

@@ -46,11 +46,12 @@ def get_current_weather(location, unit="fahrenheit"):
     "model",
     [
         "gpt-3.5-turbo-1106",
-        # "mistral/mistral-large-latest",
+        "mistral/mistral-large-latest",
         "claude-3-haiku-20240307",
         "gemini/gemini-1.5-pro",
         "anthropic.claude-3-sonnet-20240229-v1:0",
-        # "groq/llama3-8b-8192",
+        "groq/llama3-8b-8192",
+        "cohere_chat/command-r",
     ],
 )
 @pytest.mark.flaky(retries=3, delay=1)

View file

@@ -53,10 +53,17 @@ def test_async_langsmith_logging_with_metadata():
 @pytest.mark.asyncio
 async def test_async_langsmith_logging_with_streaming_and_metadata(sync_mode):
     try:
+        litellm.DEFAULT_BATCH_SIZE = 1
+        litellm.DEFAULT_FLUSH_INTERVAL_SECONDS = 1
         test_langsmith_logger = LangsmithLogger()
         litellm.success_callback = ["langsmith"]
         litellm.set_verbose = True
-        run_id = str(uuid.uuid4())
+        run_id = "497f6eca-6276-4993-bfeb-53cbbbba6f08"
+        run_name = "litellmRUN"
+        test_metadata = {
+            "run_name": run_name,  # langsmith run name
+            "run_id": run_id,  # langsmith run id
+        }

         messages = [{"role": "user", "content": "what llm are u"}]
         if sync_mode is True:
@@ -66,7 +73,7 @@ async def test_async_langsmith_logging_with_streaming_and_metadata(sync_mode):
                 max_tokens=10,
                 temperature=0.2,
                 stream=True,
-                metadata={"id": run_id},
+                metadata=test_metadata,
             )
             for cb in litellm.callbacks:
                 if isinstance(cb, LangsmithLogger):
@@ -82,7 +89,7 @@ async def test_async_langsmith_logging_with_streaming_and_metadata(sync_mode):
                 temperature=0.2,
                 mock_response="This is a mock request",
                 stream=True,
-                metadata={"id": run_id},
+                metadata=test_metadata,
            )
             for cb in litellm.callbacks:
                 if isinstance(cb, LangsmithLogger):
@@ -100,11 +107,16 @@ async def test_async_langsmith_logging_with_streaming_and_metadata(sync_mode):

         input_fields_on_langsmith = logged_run_on_langsmith.get("inputs")

-        extra_fields_on_langsmith = logged_run_on_langsmith.get("extra").get(
+        extra_fields_on_langsmith = logged_run_on_langsmith.get("extra", {}).get(
             "invocation_params"
         )

-        assert logged_run_on_langsmith.get("run_type") == "llm"
+        assert (
+            logged_run_on_langsmith.get("run_type") == "llm"
+        ), f"run_type should be llm. Got: {logged_run_on_langsmith.get('run_type')}"
+        assert (
+            logged_run_on_langsmith.get("name") == run_name
+        ), f"run_type should be llm. Got: {logged_run_on_langsmith.get('run_type')}"
         print("\nLogged INPUT ON LANGSMITH", input_fields_on_langsmith)

         print("\nextra fields on langsmith", extra_fields_on_langsmith)