(feat) - Add seed to Cohere Chat.

This commit is contained in:
David Manouchehri 2024-04-18 20:49:44 +00:00
parent f610061a79
commit f65c02d43a
No known key found for this signature in database
2 changed files with 6 additions and 0 deletions

View file

@ -43,6 +43,7 @@ class CohereChatConfig:
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens. presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking. tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools. tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
""" """
preamble: Optional[str] = None preamble: Optional[str] = None
@ -62,6 +63,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None presence_penalty: Optional[int] = None
tools: Optional[list] = None tools: Optional[list] = None
tool_results: Optional[list] = None tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__( def __init__(
self, self,
@ -82,6 +84,7 @@ class CohereChatConfig:
presence_penalty: Optional[int] = None, presence_penalty: Optional[int] = None,
tools: Optional[list] = None, tools: Optional[list] = None,
tool_results: Optional[list] = None, tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():

View file

@ -4739,6 +4739,8 @@ def get_optional_params(
optional_params["stop_sequences"] = stop optional_params["stop_sequences"] = stop
if tools is not None: if tools is not None:
optional_params["tools"] = tools optional_params["tools"] = tools
if seed is not None:
optional_params["seed"] = seed
elif custom_llm_provider == "maritalk": elif custom_llm_provider == "maritalk":
## check if unsupported param passed in ## check if unsupported param passed in
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -5517,6 +5519,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"n", "n",
"tools", "tools",
"tool_choice", "tool_choice",
"seed",
] ]
elif custom_llm_provider == "maritalk": elif custom_llm_provider == "maritalk":
return [ return [