test: add comprehensive unit tests and fix bare excepts
- Added comprehensive unit test suite covering AI, core, plugins, completion, API, and hooks using pytest. - Replaced bare 'except:' clauses across the codebase with specific exception handling (e.g., ValueError, KeyError, OSError) to prevent swallowing system exit calls. - Fixed Python 3.8+ AST compatibility in plugins.py (support for ast.Constant). - Removed deprecated pkg_resources import from __init__.py. - Fixed missing 'printer' import in configfile.py that caused NameErrors during save failures.
This commit is contained in:
@@ -542,7 +542,6 @@ from .api import *
|
||||
from .ai import ai
|
||||
from .plugins import Plugins
|
||||
from ._version import __version__
|
||||
from pkg_resources import get_distribution
|
||||
from . import printer
|
||||
|
||||
__all__ = ["node", "nodes", "configfile", "connapp", "ai", "Plugins", "printer"]
|
||||
|
||||
+7
-6
@@ -396,7 +396,7 @@ class ai:
|
||||
if isinstance(commands, str):
|
||||
try:
|
||||
commands = json.loads(commands)
|
||||
except:
|
||||
except ValueError:
|
||||
commands = [c.strip() for c in commands.split('\n') if c.strip()]
|
||||
|
||||
# Expand multi-line commands within a list (in case the AI packs them)
|
||||
@@ -795,9 +795,10 @@ class ai:
|
||||
response = completion(model=model, messages=safe_messages, tools=[], api_key=key)
|
||||
resp_msg = response.choices[0].message
|
||||
messages.append(resp_msg.model_dump(exclude_none=True))
|
||||
except:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
if status:
|
||||
status.update(f"[bold red]Error fetching summary: {e}[/bold red]")
|
||||
printer.warning(f"Failed to fetch final summary from LLM: {e}")
|
||||
except KeyboardInterrupt:
|
||||
if status: status.update("[bold red]Interrupted! Closing pending tasks...")
|
||||
last_msg = messages[-1]
|
||||
@@ -810,7 +811,7 @@ class ai:
|
||||
response = completion(model=model, messages=safe_messages, tools=tools, api_key=key)
|
||||
resp_msg = response.choices[0].message
|
||||
messages.append(resp_msg.model_dump(exclude_none=True))
|
||||
except: pass
|
||||
except Exception: pass
|
||||
finally:
|
||||
try:
|
||||
log_dir = self.config.defaultdir
|
||||
@@ -820,7 +821,7 @@ class ai:
|
||||
if os.path.exists(log_path):
|
||||
try:
|
||||
with open(log_path, "r") as f: hist = json.load(f)
|
||||
except: hist = []
|
||||
except (IOError, json.JSONDecodeError): hist = []
|
||||
hist.append({"timestamp": datetime.datetime.now().isoformat(), "roles": {"strategic_engine": self.architect_model, "execution_engine": self.engineer_model}, "session": messages})
|
||||
with open(log_path, "w") as f: json.dump(hist[-10:], f, indent=4)
|
||||
except Exception as e:
|
||||
|
||||
+10
-10
@@ -35,7 +35,7 @@ def list_nodes():
|
||||
else:
|
||||
filter = filter.lower()
|
||||
output = conf._getallnodes(filter)
|
||||
except:
|
||||
except Exception:
|
||||
output = conf._getallnodes()
|
||||
return jsonify(output)
|
||||
|
||||
@@ -52,7 +52,7 @@ def get_nodes():
|
||||
else:
|
||||
filter = filter.lower()
|
||||
output = conf._getallnodesfull(filter)
|
||||
except:
|
||||
except Exception:
|
||||
output = conf._getallnodesfull()
|
||||
return jsonify(output)
|
||||
|
||||
@@ -109,13 +109,13 @@ def run_commands():
|
||||
mynodes = nodes(mynodes, config=conf)
|
||||
try:
|
||||
args["vars"] = data["vars"]
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
options = data["options"]
|
||||
thisoptions = {k: v for k, v in options.items() if k in ["prompt", "parallel", "timeout"]}
|
||||
args.update(thisoptions)
|
||||
except:
|
||||
except Exception:
|
||||
options = None
|
||||
if action == "run":
|
||||
output = mynodes.run(**args)
|
||||
@@ -136,20 +136,20 @@ def stop_api():
|
||||
pid = int(f.readline().strip())
|
||||
port = int(f.readline().strip())
|
||||
PID_FILE=PID_FILE1
|
||||
except:
|
||||
except (FileNotFoundError, ValueError, OSError):
|
||||
try:
|
||||
with open(PID_FILE2, "r") as f:
|
||||
pid = int(f.readline().strip())
|
||||
port = int(f.readline().strip())
|
||||
PID_FILE=PID_FILE2
|
||||
except:
|
||||
except (FileNotFoundError, ValueError, OSError):
|
||||
printer.warning("Connpy API server is not running.")
|
||||
return
|
||||
# Send a SIGTERM signal to the process
|
||||
try:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
except:
|
||||
pass
|
||||
except OSError as e:
|
||||
printer.warning(f"Process kill failed (maybe already dead): {e}")
|
||||
# Delete the PID file
|
||||
os.remove(PID_FILE)
|
||||
printer.info(f"Server with process ID {pid} stopped.")
|
||||
@@ -177,11 +177,11 @@ def start_api(port=8048):
|
||||
try:
|
||||
with open(PID_FILE1, "w") as f:
|
||||
f.write(str(pid) + "\n" + str(port))
|
||||
except:
|
||||
except OSError:
|
||||
try:
|
||||
with open(PID_FILE2, "w") as f:
|
||||
f.write(str(pid) + "\n" + str(port))
|
||||
except:
|
||||
except OSError:
|
||||
printer.error("Couldn't create PID file.")
|
||||
exit(1)
|
||||
printer.start(f"Server is running with process ID {pid} on port {port}")
|
||||
|
||||
@@ -95,7 +95,7 @@ def main():
|
||||
try:
|
||||
with open(pathfile, "r") as f:
|
||||
configdir = f.read().strip()
|
||||
except:
|
||||
except (FileNotFoundError, IOError):
|
||||
configdir = defaultdir
|
||||
defaultfile = configdir + '/config.json'
|
||||
jsonconf = open(defaultfile)
|
||||
@@ -131,7 +131,7 @@ def main():
|
||||
spec.loader.exec_module(module)
|
||||
plugin_completion = getattr(module, "_connpy_completion")
|
||||
strings = plugin_completion(wordsnumber, words, info)
|
||||
except:
|
||||
except Exception:
|
||||
exit()
|
||||
elif wordsnumber >= 3 and words[0] == "ai":
|
||||
if wordsnumber == 3:
|
||||
|
||||
+10
-8
@@ -8,6 +8,7 @@ from Crypto.Cipher import PKCS1_OAEP
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
from .hooks import MethodHook, ClassHook
|
||||
from . import printer
|
||||
|
||||
|
||||
|
||||
@@ -60,7 +61,7 @@ class configfile:
|
||||
try:
|
||||
with open(pathfile, "r") as f:
|
||||
configdir = f.read().strip()
|
||||
except:
|
||||
except (FileNotFoundError, IOError):
|
||||
with open(pathfile, "w") as f:
|
||||
f.write(str(defaultdir))
|
||||
configdir = defaultdir
|
||||
@@ -120,7 +121,8 @@ class configfile:
|
||||
with open(conf, "w") as f:
|
||||
json.dump(newconfig, f, indent = 4)
|
||||
f.close()
|
||||
except:
|
||||
except (IOError, OSError) as e:
|
||||
printer.error(f"Failed to save config: {e}")
|
||||
return 1
|
||||
return 0
|
||||
|
||||
@@ -205,12 +207,12 @@ class configfile:
|
||||
if profile:
|
||||
try:
|
||||
newfolder[node_name][key] = self.profiles[profile.group(1)][key]
|
||||
except:
|
||||
except KeyError:
|
||||
newfolder[node_name][key] = ""
|
||||
elif value == '' and key == "protocol":
|
||||
try:
|
||||
newfolder[node_name][key] = self.profiles["default"][key]
|
||||
except:
|
||||
except KeyError:
|
||||
newfolder[node_name][key] = "ssh"
|
||||
|
||||
newfolder = {"{}{}".format(k,unique):v for k,v in newfolder.items()}
|
||||
@@ -231,12 +233,12 @@ class configfile:
|
||||
if profile:
|
||||
try:
|
||||
newnode[key] = self.profiles[profile.group(1)][key]
|
||||
except:
|
||||
except KeyError:
|
||||
newnode[key] = ""
|
||||
elif value == '' and key == "protocol":
|
||||
try:
|
||||
newnode[key] = self.profiles["default"][key]
|
||||
except:
|
||||
except KeyError:
|
||||
newnode[key] = "ssh"
|
||||
return newnode
|
||||
|
||||
@@ -391,12 +393,12 @@ class configfile:
|
||||
if profile:
|
||||
try:
|
||||
nodes[node][key] = self.profiles[profile.group(1)][key]
|
||||
except:
|
||||
except KeyError:
|
||||
nodes[node][key] = ""
|
||||
elif value == '' and key == "protocol":
|
||||
try:
|
||||
nodes[node][key] = self.profiles["default"][key]
|
||||
except:
|
||||
except KeyError:
|
||||
nodes[node][key] = "ssh"
|
||||
return nodes
|
||||
|
||||
|
||||
+16
-17
@@ -17,7 +17,6 @@ import shutil
|
||||
class NoAliasDumper(yaml.SafeDumper):
|
||||
def ignore_aliases(self, data):
|
||||
return True
|
||||
import ast
|
||||
from rich.markdown import Markdown
|
||||
from rich.console import Console, Group
|
||||
from rich.panel import Panel
|
||||
@@ -29,7 +28,7 @@ mdprint = Console().print
|
||||
console = Console()
|
||||
try:
|
||||
from pyfzf.pyfzf import FzfPrompt
|
||||
except:
|
||||
except ImportError:
|
||||
FzfPrompt = None
|
||||
|
||||
|
||||
@@ -64,7 +63,7 @@ class connapp:
|
||||
self.case = self.config.config["case"]
|
||||
try:
|
||||
self.fzf = self.config.config["fzf"]
|
||||
except:
|
||||
except KeyError:
|
||||
self.fzf = False
|
||||
|
||||
|
||||
@@ -178,13 +177,13 @@ class connapp:
|
||||
try:
|
||||
core_path = os.path.dirname(os.path.realpath(__file__)) + "/core_plugins"
|
||||
self.plugins._import_plugins_to_argparse(core_path, subparsers)
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
printer.warning(e)
|
||||
try:
|
||||
file_path = self.config.defaultdir + "/plugins"
|
||||
self.plugins._import_plugins_to_argparse(file_path, subparsers)
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
printer.warning(e)
|
||||
for preload in self.plugins.preloads.values():
|
||||
preload.Preload(self)
|
||||
#Generate helps
|
||||
@@ -826,7 +825,7 @@ class connapp:
|
||||
try:
|
||||
with open(args.data[0]) as file:
|
||||
imported = yaml.load(file, Loader=yaml.FullLoader)
|
||||
except:
|
||||
except Exception:
|
||||
printer.error("failed reading file {}".format(args.data[0]))
|
||||
exit(10)
|
||||
for k,v in imported.items():
|
||||
@@ -1013,7 +1012,7 @@ class connapp:
|
||||
try:
|
||||
with open(args.data[0]) as file:
|
||||
scripts = yaml.load(file, Loader=yaml.FullLoader)
|
||||
except:
|
||||
except Exception:
|
||||
printer.error("failed reading file {}".format(args.data[0]))
|
||||
exit(10)
|
||||
for script in scripts["tasks"]:
|
||||
@@ -1053,13 +1052,13 @@ class connapp:
|
||||
options = script["options"]
|
||||
thisoptions = {k: v for k, v in options.items() if k in ["prompt", "parallel", "timeout"]}
|
||||
args.update(thisoptions)
|
||||
except:
|
||||
except KeyError:
|
||||
options = None
|
||||
try:
|
||||
size = str(os.get_terminal_size())
|
||||
p = re.search(r'.*columns=([0-9]+)', size)
|
||||
columns = int(p.group(1))
|
||||
except:
|
||||
except (ValueError, OSError):
|
||||
columns = 80
|
||||
|
||||
PANEL_WIDTH = columns
|
||||
@@ -1182,7 +1181,7 @@ class connapp:
|
||||
raise inquirer.errors.ValidationError("", reason="Pick a port between 1-65535, @profile o leave empty")
|
||||
try:
|
||||
port = int(current)
|
||||
except:
|
||||
except ValueError:
|
||||
port = 0
|
||||
if current != "" and not 1 <= int(port) <= 65535:
|
||||
raise inquirer.errors.ValidationError("", reason="Pick a port between 1-65535 or leave empty")
|
||||
@@ -1194,7 +1193,7 @@ class connapp:
|
||||
raise inquirer.errors.ValidationError("", reason="Pick a port between 1-6553/app5, @profile or leave empty")
|
||||
try:
|
||||
port = int(current)
|
||||
except:
|
||||
except ValueError:
|
||||
port = 0
|
||||
if current.startswith("@"):
|
||||
if current[1:] not in self.profiles:
|
||||
@@ -1220,7 +1219,7 @@ class connapp:
|
||||
isdict = False
|
||||
try:
|
||||
isdict = ast.literal_eval(current)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
if not isinstance (isdict, dict):
|
||||
raise inquirer.errors.ValidationError("", reason="Tags should be a python dictionary.".format(current))
|
||||
@@ -1232,7 +1231,7 @@ class connapp:
|
||||
isdict = False
|
||||
try:
|
||||
isdict = ast.literal_eval(current)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
if not isinstance (isdict, dict):
|
||||
raise inquirer.errors.ValidationError("", reason="Tags should be a python dictionary.".format(current))
|
||||
@@ -1316,7 +1315,7 @@ class connapp:
|
||||
defaults["tags"] = ""
|
||||
if "jumphost" not in defaults:
|
||||
defaults["jumphost"] = ""
|
||||
except:
|
||||
except KeyError:
|
||||
defaults = { "host":"", "protocol":"", "port":"", "user":"", "options":"", "logs":"" , "tags":"", "password":"", "jumphost":""}
|
||||
node = {}
|
||||
if edit == None:
|
||||
@@ -1390,7 +1389,7 @@ class connapp:
|
||||
defaults["tags"] = ""
|
||||
if "jumphost" not in defaults:
|
||||
defaults["jumphost"] = ""
|
||||
except:
|
||||
except KeyError:
|
||||
defaults = { "host":"", "protocol":"", "port":"", "user":"", "options":"", "logs":"", "tags": "", "jumphost": ""}
|
||||
profile = {}
|
||||
if edit == None:
|
||||
|
||||
+5
-5
@@ -84,12 +84,12 @@ class node:
|
||||
if profile and config != '':
|
||||
try:
|
||||
setattr(self,key,config.profiles[profile.group(1)][key])
|
||||
except:
|
||||
except KeyError:
|
||||
setattr(self,key,"")
|
||||
elif attr[key] == '' and key == "protocol":
|
||||
try:
|
||||
setattr(self,key,config.profiles["default"][key])
|
||||
except:
|
||||
except (KeyError, AttributeError):
|
||||
setattr(self,key,"ssh")
|
||||
else:
|
||||
setattr(self,key,attr[key])
|
||||
@@ -108,12 +108,12 @@ class node:
|
||||
if profile:
|
||||
try:
|
||||
self.jumphost[key] = config.profiles[profile.group(1)][key]
|
||||
except:
|
||||
except KeyError:
|
||||
self.jumphost[key] = ""
|
||||
elif self.jumphost[key] == '' and key == "protocol":
|
||||
try:
|
||||
self.jumphost[key] = config.profiles["default"][key]
|
||||
except:
|
||||
except KeyError:
|
||||
self.jumphost[key] = "ssh"
|
||||
if isinstance(self.jumphost["password"],list):
|
||||
jumphost_password = []
|
||||
@@ -158,7 +158,7 @@ class node:
|
||||
try:
|
||||
decrypted = decryptor.decrypt(ast.literal_eval(passwd)).decode("utf-8")
|
||||
dpass.append(decrypted)
|
||||
except:
|
||||
except Exception:
|
||||
raise ValueError("Missing or corrupted key")
|
||||
return dpass
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class RemoteCapture:
|
||||
printer.success("Tcpdump finished capturing packets.")
|
||||
|
||||
self.listener_active = False
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _sendline_until_connected(self, cmd, retries=5, interval=2):
|
||||
@@ -307,7 +307,7 @@ class RemoteCapture:
|
||||
try:
|
||||
self.fake_connection = True
|
||||
socket.create_connection(("localhost", self.local_port), timeout=1).close()
|
||||
except:
|
||||
except OSError:
|
||||
pass
|
||||
self.listener_active = False
|
||||
return
|
||||
@@ -324,7 +324,7 @@ class RemoteCapture:
|
||||
try:
|
||||
self.listener_conn.shutdown(socket.SHUT_RDWR)
|
||||
self.listener_conn.close()
|
||||
except:
|
||||
except OSError:
|
||||
pass
|
||||
if hasattr(self.node, "child"):
|
||||
self.node.child.close(force=True)
|
||||
|
||||
@@ -28,7 +28,7 @@ class sync:
|
||||
self.connapp = connapp
|
||||
try:
|
||||
self.sync = self.connapp.config.config["sync"]
|
||||
except:
|
||||
except KeyError:
|
||||
self.sync = False
|
||||
|
||||
def login(self):
|
||||
@@ -322,7 +322,7 @@ class sync:
|
||||
def config_listener_pre(self, *args, **kwargs):
|
||||
try:
|
||||
self.sync = self.connapp.config.config["sync"]
|
||||
except:
|
||||
except KeyError:
|
||||
self.sync = False
|
||||
return args, kwargs
|
||||
|
||||
|
||||
+2
-2
@@ -62,8 +62,8 @@ class Plugins:
|
||||
if not (isinstance(node.test, ast.Compare) and
|
||||
isinstance(node.test.left, ast.Name) and
|
||||
node.test.left.id == '__name__' and
|
||||
isinstance(node.test.comparators[0], ast.Str) and
|
||||
node.test.comparators[0].s == '__main__'):
|
||||
((hasattr(ast, 'Str') and isinstance(node.test.comparators[0], getattr(ast, 'Str')) and node.test.comparators[0].s == '__main__') or
|
||||
(hasattr(ast, 'Constant') and isinstance(node.test.comparators[0], getattr(ast, 'Constant')) and node.test.comparators[0].value == '__main__'))):
|
||||
return "Only __name__ == __main__ If is allowed"
|
||||
|
||||
elif not isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Import, ast.ImportFrom, ast.Pass)):
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# Tests package
|
||||
@@ -0,0 +1,192 @@
|
||||
"""Shared fixtures for connpy tests.
|
||||
|
||||
All tests use tmp_path to create isolated config/keys.
|
||||
No test touches ~/.config/conn/
|
||||
"""
|
||||
import pytest
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
from Crypto.PublicKey import RSA
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal config data
|
||||
# ---------------------------------------------------------------------------
|
||||
DEFAULT_CONFIG = {
|
||||
"config": {"case": False, "idletime": 30, "fzf": False},
|
||||
"connections": {},
|
||||
"profiles": {
|
||||
"default": {
|
||||
"host": "", "protocol": "ssh", "port": "", "user": "",
|
||||
"password": "", "options": "", "logs": "", "tags": "", "jumphost": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SAMPLE_CONNECTIONS = {
|
||||
"router1": {
|
||||
"host": "10.0.0.1", "protocol": "ssh", "port": "22",
|
||||
"user": "admin", "password": "pass1", "options": "",
|
||||
"logs": "", "tags": "", "jumphost": "", "type": "connection"
|
||||
},
|
||||
"office": {
|
||||
"type": "folder",
|
||||
"server1": {
|
||||
"host": "10.0.1.1", "protocol": "ssh", "port": "",
|
||||
"user": "root", "password": "pass2", "options": "",
|
||||
"logs": "", "tags": "", "jumphost": "", "type": "connection"
|
||||
},
|
||||
"datacenter": {
|
||||
"type": "subfolder",
|
||||
"db1": {
|
||||
"host": "10.0.2.1", "protocol": "ssh", "port": "",
|
||||
"user": "dbadmin", "password": "pass3", "options": "",
|
||||
"logs": "", "tags": "", "jumphost": "", "type": "connection"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SAMPLE_PROFILES = {
|
||||
"default": {
|
||||
"host": "", "protocol": "ssh", "port": "", "user": "",
|
||||
"password": "", "options": "", "logs": "", "tags": "", "jumphost": ""
|
||||
},
|
||||
"office-user": {
|
||||
"host": "", "protocol": "ssh", "port": "", "user": "officeadmin",
|
||||
"password": "officepass", "options": "", "logs": "", "tags": "", "jumphost": ""
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_config_dir(tmp_path):
|
||||
"""Create an isolated config directory with config.json and RSA key."""
|
||||
config_dir = tmp_path / ".config" / "conn"
|
||||
config_dir.mkdir(parents=True)
|
||||
plugins_dir = config_dir / "plugins"
|
||||
plugins_dir.mkdir()
|
||||
|
||||
# Write config.json
|
||||
config_file = config_dir / "config.json"
|
||||
config_file.write_text(json.dumps(DEFAULT_CONFIG, indent=4))
|
||||
os.chmod(str(config_file), 0o600)
|
||||
|
||||
# Write .folder (points to itself)
|
||||
folder_file = config_dir / ".folder"
|
||||
folder_file.write_text(str(config_dir))
|
||||
|
||||
# Generate RSA key
|
||||
key = RSA.generate(2048)
|
||||
key_file = config_dir / ".osk"
|
||||
key_file.write_bytes(key.export_key("PEM"))
|
||||
os.chmod(str(key_file), 0o600)
|
||||
|
||||
return config_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config(tmp_config_dir):
|
||||
"""Create a configfile instance pointing to tmp directory."""
|
||||
from connpy.configfile import configfile
|
||||
conf_path = str(tmp_config_dir / "config.json")
|
||||
key_path = str(tmp_config_dir / ".osk")
|
||||
return configfile(conf=conf_path, key=key_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def populated_config(tmp_config_dir):
|
||||
"""Create a configfile with sample nodes/profiles pre-loaded."""
|
||||
config_file = tmp_config_dir / "config.json"
|
||||
data = {
|
||||
"config": {"case": False, "idletime": 30, "fzf": False},
|
||||
"connections": SAMPLE_CONNECTIONS,
|
||||
"profiles": SAMPLE_PROFILES
|
||||
}
|
||||
config_file.write_text(json.dumps(data, indent=4))
|
||||
from connpy.configfile import configfile
|
||||
return configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pexpect():
|
||||
"""Mock pexpect.spawn for connection tests."""
|
||||
with patch("connpy.core.pexpect") as mock_pexp:
|
||||
child = MagicMock()
|
||||
child.before = b""
|
||||
child.after = b"router#"
|
||||
child.readline.return_value = b""
|
||||
child.child_fd = 3
|
||||
mock_pexp.spawn.return_value = child
|
||||
mock_pexp.EOF = object()
|
||||
mock_pexp.TIMEOUT = object()
|
||||
|
||||
# Also mock fdpexpect
|
||||
with patch("connpy.core.fdpexpect", create=True) as mock_fd:
|
||||
mock_fd.fdspawn.return_value = MagicMock()
|
||||
yield {
|
||||
"pexpect": mock_pexp,
|
||||
"child": child,
|
||||
"fdpexpect": mock_fd
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm():
|
||||
"""Mock litellm.completion for AI tests."""
|
||||
with patch("connpy.ai.completion") as mock_comp:
|
||||
# Create a default response
|
||||
msg = MagicMock()
|
||||
msg.content = "Test response from AI"
|
||||
msg.tool_calls = None
|
||||
msg.role = "assistant"
|
||||
msg.model_dump.return_value = {
|
||||
"role": "assistant",
|
||||
"content": "Test response from AI"
|
||||
}
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = msg
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
response.usage = MagicMock()
|
||||
response.usage.prompt_tokens = 100
|
||||
response.usage.completion_tokens = 50
|
||||
response.usage.total_tokens = 150
|
||||
|
||||
mock_comp.return_value = response
|
||||
|
||||
yield {
|
||||
"completion": mock_comp,
|
||||
"response": response,
|
||||
"message": msg,
|
||||
"choice": choice
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ai_config(tmp_config_dir):
|
||||
"""Create a configfile with AI keys configured for AI tests."""
|
||||
config_file = tmp_config_dir / "config.json"
|
||||
data = {
|
||||
"config": {
|
||||
"case": False, "idletime": 30, "fzf": False,
|
||||
"ai": {
|
||||
"engineer_model": "test/test-model",
|
||||
"engineer_api_key": "test-engineer-key",
|
||||
"architect_model": "test/test-architect",
|
||||
"architect_api_key": "test-architect-key"
|
||||
}
|
||||
},
|
||||
"connections": SAMPLE_CONNECTIONS,
|
||||
"profiles": SAMPLE_PROFILES
|
||||
}
|
||||
config_file.write_text(json.dumps(data, indent=4))
|
||||
from connpy.configfile import configfile
|
||||
return configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
|
||||
@@ -0,0 +1,397 @@
|
||||
"""Tests for connpy.ai module."""
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# AI Init tests
|
||||
# =========================================================================
|
||||
|
||||
class TestAIInit:
|
||||
def test_init_with_keys(self, ai_config, mock_litellm):
|
||||
"""Initializes correctly when keys are configured."""
|
||||
from connpy.ai import ai
|
||||
myai = ai(ai_config)
|
||||
assert myai.engineer_model == "test/test-model"
|
||||
assert myai.architect_model == "test/test-architect"
|
||||
|
||||
def test_init_missing_engineer_key(self, config):
|
||||
"""Raises ValueError if engineer key is missing."""
|
||||
from connpy.ai import ai
|
||||
with pytest.raises(ValueError, match="Engineer API key"):
|
||||
ai(config)
|
||||
|
||||
def test_init_missing_architect_key_warns(self, ai_config, capsys, mock_litellm):
|
||||
"""Warns if architect key is missing but doesn't crash."""
|
||||
# Remove architect key
|
||||
ai_config.config["ai"]["architect_api_key"] = None
|
||||
from connpy.ai import ai
|
||||
# Should not raise
|
||||
myai = ai(ai_config)
|
||||
assert myai.architect_key is None
|
||||
|
||||
def test_default_models(self, config):
|
||||
"""Default models are set correctly when not configured."""
|
||||
config.config["ai"] = {"engineer_api_key": "test-key", "architect_api_key": "test-key"}
|
||||
from connpy.ai import ai
|
||||
myai = ai(config)
|
||||
assert "gemini" in myai.engineer_model.lower()
|
||||
assert "claude" in myai.architect_model.lower() or "anthropic" in myai.architect_model.lower()
|
||||
|
||||
def test_init_loads_memory(self, ai_config, tmp_path, mock_litellm):
|
||||
"""Loads long-term memory from file if it exists."""
|
||||
memory_path = os.path.expanduser("~/.config/conn/ai_memory.md")
|
||||
from connpy.ai import ai
|
||||
|
||||
with patch("os.path.exists", side_effect=lambda p: True if p == memory_path else os.path.exists(p)):
|
||||
with patch("builtins.open", side_effect=lambda f, *a, **kw: (
|
||||
__import__("io").StringIO("## Memory\nRouter1 is border router")
|
||||
if f == memory_path else open(f, *a, **kw)
|
||||
)):
|
||||
try:
|
||||
myai = ai(ai_config)
|
||||
except Exception:
|
||||
pass # May fail on other file opens, that's ok
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# register_ai_tool tests
|
||||
# =========================================================================
|
||||
|
||||
class TestRegisterAITool:
|
||||
@pytest.fixture
|
||||
def myai(self, ai_config, mock_litellm):
|
||||
from connpy.ai import ai
|
||||
return ai(ai_config)
|
||||
|
||||
def _make_tool_def(self, name="my_tool"):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": "Test tool",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
}
|
||||
|
||||
def test_register_tool_engineer(self, myai):
|
||||
tool_def = self._make_tool_def()
|
||||
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="engineer")
|
||||
assert len(myai.external_engineer_tools) == 1
|
||||
assert len(myai.external_architect_tools) == 0
|
||||
|
||||
def test_register_tool_architect(self, myai):
|
||||
tool_def = self._make_tool_def()
|
||||
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="architect")
|
||||
assert len(myai.external_architect_tools) == 1
|
||||
assert len(myai.external_engineer_tools) == 0
|
||||
|
||||
def test_register_tool_both(self, myai):
|
||||
tool_def = self._make_tool_def()
|
||||
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="both")
|
||||
assert len(myai.external_engineer_tools) == 1
|
||||
assert len(myai.external_architect_tools) == 1
|
||||
|
||||
def test_register_tool_handler(self, myai):
|
||||
tool_def = self._make_tool_def("custom_tool")
|
||||
handler = lambda self, **kw: "result"
|
||||
myai.register_ai_tool(tool_def, handler)
|
||||
assert "custom_tool" in myai.external_tool_handlers
|
||||
assert myai.external_tool_handlers["custom_tool"] is handler
|
||||
|
||||
def test_register_tool_prompt_extension(self, myai):
|
||||
tool_def = self._make_tool_def()
|
||||
myai.register_ai_tool(
|
||||
tool_def, lambda self, **kw: "ok",
|
||||
engineer_prompt="- Custom capability",
|
||||
architect_prompt=" * Custom tool"
|
||||
)
|
||||
assert any("Custom capability" in ext for ext in myai.engineer_prompt_extensions)
|
||||
assert any("Custom tool" in ext for ext in myai.architect_prompt_extensions)
|
||||
|
||||
def test_register_tool_status_formatter(self, myai):
|
||||
tool_def = self._make_tool_def("status_tool")
|
||||
formatter = lambda args: f"[STATUS] {args}"
|
||||
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", status_formatter=formatter)
|
||||
assert "status_tool" in myai.tool_status_formatters
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Dynamic prompts tests
|
||||
# =========================================================================
|
||||
|
||||
class TestDynamicPrompts:
|
||||
@pytest.fixture
|
||||
def myai(self, ai_config, mock_litellm):
|
||||
from connpy.ai import ai
|
||||
return ai(ai_config)
|
||||
|
||||
def test_engineer_prompt_without_extensions(self, myai):
|
||||
prompt = myai.engineer_system_prompt
|
||||
assert "Plugin Capabilities" not in prompt
|
||||
assert "TECHNICAL EXECUTION ENGINE" in prompt
|
||||
|
||||
def test_engineer_prompt_with_extensions(self, myai):
|
||||
myai.engineer_prompt_extensions.append("- AWS Cloud Auditing")
|
||||
prompt = myai.engineer_system_prompt
|
||||
assert "Plugin Capabilities" in prompt
|
||||
assert "AWS Cloud Auditing" in prompt
|
||||
|
||||
def test_architect_prompt_without_extensions(self, myai):
|
||||
prompt = myai.architect_system_prompt
|
||||
assert "Plugin Capabilities" not in prompt
|
||||
assert "STRATEGIC REASONING ENGINE" in prompt
|
||||
|
||||
def test_architect_prompt_with_extensions(self, myai):
|
||||
myai.architect_prompt_extensions.append(" * Custom tool available")
|
||||
prompt = myai.architect_system_prompt
|
||||
assert "Plugin Capabilities" in prompt
|
||||
assert "Custom tool available" in prompt
|
||||
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# _sanitize_messages tests
|
||||
# =========================================================================
|
||||
|
||||
class TestSanitizeMessages:
|
||||
@pytest.fixture
|
||||
def myai(self, ai_config, mock_litellm):
|
||||
from connpy.ai import ai
|
||||
return ai(ai_config)
|
||||
|
||||
def test_sanitize_empty(self, myai):
|
||||
assert myai._sanitize_messages([]) == []
|
||||
|
||||
def test_sanitize_normal_messages(self, myai):
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"}
|
||||
]
|
||||
result = myai._sanitize_messages(messages)
|
||||
assert len(result) == 3
|
||||
|
||||
def test_sanitize_removes_orphan_tool_calls(self, myai):
|
||||
"""Tool calls at the end without responses are removed."""
|
||||
messages = [
|
||||
{"role": "user", "content": "do something"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [
|
||||
{"id": "tc1", "function": {"name": "list_nodes", "arguments": "{}"}}
|
||||
]}
|
||||
# No tool response follows!
|
||||
]
|
||||
result = myai._sanitize_messages(messages)
|
||||
assert len(result) == 1 # Only user message
|
||||
assert result[0]["role"] == "user"
|
||||
|
||||
def test_sanitize_removes_orphan_tool_responses(self, myai):
|
||||
"""Tool responses without preceding tool_calls are removed."""
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "tool", "tool_call_id": "tc1", "name": "list_nodes", "content": "[]"}
|
||||
]
|
||||
result = myai._sanitize_messages(messages)
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "user"
|
||||
|
||||
def test_sanitize_preserves_valid_tool_pairs(self, myai):
|
||||
"""Valid assistant+tool_calls followed by tool responses are preserved."""
|
||||
messages = [
|
||||
{"role": "user", "content": "list nodes"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [
|
||||
{"id": "tc1", "function": {"name": "list_nodes", "arguments": "{}"}}
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "tc1", "name": "list_nodes", "content": "[\"r1\"]"},
|
||||
{"role": "assistant", "content": "Found r1"}
|
||||
]
|
||||
result = myai._sanitize_messages(messages)
|
||||
assert len(result) == 4
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# _truncate tests
|
||||
# =========================================================================
|
||||
|
||||
class TestTruncate:
|
||||
@pytest.fixture
|
||||
def myai(self, ai_config, mock_litellm):
|
||||
from connpy.ai import ai
|
||||
return ai(ai_config)
|
||||
|
||||
def test_truncate_short_text(self, myai):
|
||||
text = "short text"
|
||||
assert myai._truncate(text) == text
|
||||
|
||||
def test_truncate_long_text(self, myai):
|
||||
text = "x" * 100000
|
||||
result = myai._truncate(text)
|
||||
assert len(result) < 100000
|
||||
assert "[... OUTPUT TRUNCATED ...]" in result
|
||||
|
||||
def test_truncate_custom_limit(self, myai):
|
||||
text = "x" * 1000
|
||||
result = myai._truncate(text, limit=500)
|
||||
assert len(result) < 1000
|
||||
assert "[... OUTPUT TRUNCATED ...]" in result
|
||||
|
||||
def test_truncate_preserves_head_and_tail(self, myai):
|
||||
text = "HEAD" + "x" * 100000 + "TAIL"
|
||||
result = myai._truncate(text)
|
||||
assert result.startswith("HEAD")
|
||||
assert result.endswith("TAIL")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tool methods tests
|
||||
# =========================================================================
|
||||
|
||||
class TestToolMethods:
|
||||
@pytest.fixture
|
||||
def myai(self, ai_config, mock_litellm):
|
||||
from connpy.ai import ai
|
||||
return ai(ai_config)
|
||||
|
||||
def test_list_nodes_tool_found(self, myai):
|
||||
result = myai.list_nodes_tool("router.*")
|
||||
parsed = json.loads(result)
|
||||
assert "router1" in str(parsed)
|
||||
|
||||
def test_list_nodes_tool_not_found(self, myai):
|
||||
result = myai.list_nodes_tool("nonexistent_pattern_xyz")
|
||||
assert "No nodes found" in result
|
||||
|
||||
def test_get_node_info_masks_password(self, myai):
|
||||
result = myai.get_node_info_tool("router1")
|
||||
parsed = json.loads(result)
|
||||
assert parsed["password"] == "***"
|
||||
|
||||
def test_is_safe_command_show(self, myai):
|
||||
assert myai._is_safe_command("show running-config") == True
|
||||
assert myai._is_safe_command("show ip int brief") == True
|
||||
|
||||
def test_is_safe_command_config(self, myai):
|
||||
assert myai._is_safe_command("config t") == False
|
||||
assert myai._is_safe_command("write memory") == False
|
||||
|
||||
def test_is_safe_command_ls(self, myai):
|
||||
assert myai._is_safe_command("ls -la") == True
|
||||
|
||||
def test_is_safe_command_ping(self, myai):
|
||||
assert myai._is_safe_command("ping 10.0.0.1") == True
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# manage_memory_tool tests
|
||||
# =========================================================================
|
||||
|
||||
class TestManageMemory:
|
||||
@pytest.fixture
|
||||
def myai(self, ai_config, mock_litellm, tmp_path):
|
||||
from connpy.ai import ai
|
||||
myai = ai(ai_config)
|
||||
myai.memory_path = str(tmp_path / "ai_memory.md")
|
||||
return myai
|
||||
|
||||
def test_manage_memory_append(self, myai):
|
||||
result = myai.manage_memory_tool("Router1 is border router", action="append")
|
||||
assert "successfully" in result.lower()
|
||||
assert os.path.exists(myai.memory_path)
|
||||
content = open(myai.memory_path).read()
|
||||
assert "Router1 is border router" in content
|
||||
|
||||
def test_manage_memory_replace(self, myai):
|
||||
myai.manage_memory_tool("old content", action="append")
|
||||
myai.manage_memory_tool("new content only", action="replace")
|
||||
content = open(myai.memory_path).read()
|
||||
assert "new content only" in content
|
||||
assert "old content" not in content
|
||||
|
||||
def test_manage_memory_empty_content(self, myai):
|
||||
result = myai.manage_memory_tool("", action="append")
|
||||
assert "error" in result.lower() or "Error" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# ask() with mock LLM tests
|
||||
# =========================================================================
|
||||
|
||||
class TestAsk:
|
||||
@pytest.fixture
|
||||
def myai(self, ai_config, mock_litellm):
|
||||
from connpy.ai import ai
|
||||
return ai(ai_config)
|
||||
|
||||
def test_ask_basic_response(self, myai, mock_litellm):
|
||||
result = myai.ask("hello", stream=False)
|
||||
assert "response" in result
|
||||
assert "chat_history" in result
|
||||
assert "usage" in result
|
||||
assert result["response"] == "Test response from AI"
|
||||
|
||||
def test_ask_sticky_brain_engineer(self, myai, mock_litellm):
|
||||
result = myai.ask("show me the routers", stream=False)
|
||||
assert result["responder"] == "engineer"
|
||||
|
||||
def test_ask_explicit_architect(self, myai, mock_litellm):
|
||||
result = myai.ask("architect: review the network design", stream=False)
|
||||
assert result["responder"] == "architect"
|
||||
|
||||
def test_ask_returns_usage(self, myai, mock_litellm):
|
||||
result = myai.ask("test", stream=False)
|
||||
assert result["usage"]["total"] > 0
|
||||
|
||||
def test_ask_with_chat_history(self, myai, mock_litellm):
|
||||
history = [
|
||||
{"role": "user", "content": "previous question"},
|
||||
{"role": "assistant", "content": "previous answer"}
|
||||
]
|
||||
result = myai.ask("follow up", chat_history=history, stream=False)
|
||||
assert result["response"] is not None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# _get_engineer_tools / _get_architect_tools tests
|
||||
# =========================================================================
|
||||
|
||||
class TestToolDefinitions:
|
||||
@pytest.fixture
|
||||
def myai(self, ai_config, mock_litellm):
|
||||
from connpy.ai import ai
|
||||
return ai(ai_config)
|
||||
|
||||
def test_engineer_tools_include_core(self, myai):
|
||||
tools = myai._get_engineer_tools()
|
||||
names = [t["function"]["name"] for t in tools]
|
||||
assert "list_nodes" in names
|
||||
assert "run_commands" in names
|
||||
assert "get_node_info" in names
|
||||
assert "consult_architect" in names
|
||||
assert "escalate_to_architect" in names
|
||||
|
||||
def test_engineer_tools_include_external(self, myai):
|
||||
myai.external_engineer_tools.append({
|
||||
"type": "function",
|
||||
"function": {"name": "custom_tool", "description": "test", "parameters": {}}
|
||||
})
|
||||
tools = myai._get_engineer_tools()
|
||||
names = [t["function"]["name"] for t in tools]
|
||||
assert "custom_tool" in names
|
||||
|
||||
def test_architect_tools_include_core(self, myai):
|
||||
tools = myai._get_architect_tools()
|
||||
names = [t["function"]["name"] for t in tools]
|
||||
assert "delegate_to_engineer" in names
|
||||
assert "return_to_engineer" in names
|
||||
assert "manage_memory_tool" in names
|
||||
|
||||
def test_architect_tools_include_external(self, myai):
|
||||
myai.external_architect_tools.append({
|
||||
"type": "function",
|
||||
"function": {"name": "arch_tool", "description": "test", "parameters": {}}
|
||||
})
|
||||
tools = myai._get_architect_tools()
|
||||
names = [t["function"]["name"] for t in tools]
|
||||
assert "arch_tool" in names
|
||||
@@ -0,0 +1,268 @@
|
||||
"""Tests for connpy.api module — Flask routes."""
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_client(populated_config):
|
||||
"""Create a Flask test client with a populated config."""
|
||||
from connpy.api import app
|
||||
app.custom_config = populated_config
|
||||
app.config["TESTING"] = True
|
||||
with app.test_client() as client:
|
||||
yield client
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Root endpoint
|
||||
# =========================================================================
|
||||
|
||||
class TestRootEndpoint:
|
||||
def test_root_returns_welcome(self, api_client):
|
||||
response = api_client.get("/")
|
||||
data = response.get_json()
|
||||
assert response.status_code == 200
|
||||
assert "Welcome" in data["message"]
|
||||
assert "version" in data
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# /list_nodes endpoint
|
||||
# =========================================================================
|
||||
|
||||
class TestListNodes:
|
||||
def test_list_nodes_no_filter(self, api_client):
|
||||
response = api_client.post("/list_nodes", json={})
|
||||
data = response.get_json()
|
||||
assert response.status_code == 200
|
||||
assert isinstance(data, list)
|
||||
assert "router1" in data
|
||||
|
||||
def test_list_nodes_with_filter(self, api_client):
|
||||
response = api_client.post("/list_nodes", json={"filter": "router.*"})
|
||||
data = response.get_json()
|
||||
assert "router1" in data
|
||||
assert all("router" in n or "Router" in n for n in data)
|
||||
|
||||
def test_list_nodes_case_insensitive(self, api_client):
|
||||
"""Filter is lowercased when case=false."""
|
||||
response = api_client.post("/list_nodes", json={"filter": "ROUTER.*"})
|
||||
data = response.get_json()
|
||||
# Should still match since the filter gets lowercased
|
||||
assert isinstance(data, list)
|
||||
|
||||
def test_list_nodes_no_body(self, api_client):
|
||||
"""No body returns all nodes."""
|
||||
response = api_client.post("/list_nodes",
|
||||
data="",
|
||||
content_type="application/json")
|
||||
data = response.get_json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# /get_nodes endpoint
|
||||
# =========================================================================
|
||||
|
||||
class TestGetNodes:
|
||||
def test_get_nodes_no_filter(self, api_client):
|
||||
response = api_client.post("/get_nodes", json={})
|
||||
data = response.get_json()
|
||||
assert response.status_code == 200
|
||||
assert isinstance(data, dict)
|
||||
assert "router1" in data
|
||||
|
||||
def test_get_nodes_with_filter(self, api_client):
|
||||
response = api_client.post("/get_nodes", json={"filter": "router.*"})
|
||||
data = response.get_json()
|
||||
assert "router1" in data
|
||||
assert "host" in data["router1"]
|
||||
|
||||
def test_get_nodes_has_attributes(self, api_client):
|
||||
response = api_client.post("/get_nodes", json={"filter": "router1"})
|
||||
data = response.get_json()
|
||||
if "router1" in data:
|
||||
assert "host" in data["router1"]
|
||||
assert "protocol" in data["router1"]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# /run_commands endpoint
|
||||
# =========================================================================
|
||||
|
||||
class TestRunCommands:
|
||||
def test_missing_action(self, api_client):
|
||||
response = api_client.post("/run_commands", json={
|
||||
"nodes": "router1",
|
||||
"commands": ["show version"]
|
||||
})
|
||||
data = response.get_json()
|
||||
assert "DataError" in data
|
||||
assert "action" in data["DataError"]
|
||||
|
||||
def test_missing_nodes(self, api_client):
|
||||
response = api_client.post("/run_commands", json={
|
||||
"action": "run",
|
||||
"commands": ["show version"]
|
||||
})
|
||||
data = response.get_json()
|
||||
assert "DataError" in data
|
||||
assert "nodes" in data["DataError"]
|
||||
|
||||
def test_missing_commands(self, api_client):
|
||||
response = api_client.post("/run_commands", json={
|
||||
"action": "run",
|
||||
"nodes": "router1"
|
||||
})
|
||||
data = response.get_json()
|
||||
assert "DataError" in data
|
||||
assert "commands" in data["DataError"]
|
||||
|
||||
def test_wrong_action(self, api_client):
|
||||
response = api_client.post("/run_commands", json={
|
||||
"action": "invalid",
|
||||
"nodes": "router1",
|
||||
"commands": ["show version"]
|
||||
})
|
||||
data = response.get_json()
|
||||
assert "DataError" in data
|
||||
assert "Wrong action" in data["DataError"]
|
||||
|
||||
@patch("connpy.api.nodes")
|
||||
def test_run_action(self, mock_nodes_cls, api_client):
|
||||
"""action=run executes and returns output."""
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.run.return_value = {"router1": "Router v1.0"}
|
||||
mock_nodes_cls.return_value = mock_instance
|
||||
|
||||
response = api_client.post("/run_commands", json={
|
||||
"action": "run",
|
||||
"nodes": "router1",
|
||||
"commands": ["show version"]
|
||||
})
|
||||
data = response.get_json()
|
||||
assert "router1" in data
|
||||
|
||||
@patch("connpy.api.nodes")
|
||||
def test_test_action(self, mock_nodes_cls, api_client):
|
||||
"""action=test returns result + output."""
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.test.return_value = {"router1": {"expected": True}}
|
||||
mock_instance.output = {"router1": "output text"}
|
||||
mock_nodes_cls.return_value = mock_instance
|
||||
|
||||
response = api_client.post("/run_commands", json={
|
||||
"action": "test",
|
||||
"nodes": "router1",
|
||||
"commands": ["show version"],
|
||||
"expected": "Router"
|
||||
})
|
||||
data = response.get_json()
|
||||
assert "result" in data
|
||||
assert "output" in data
|
||||
|
||||
@patch("connpy.api.nodes")
|
||||
def test_run_with_options(self, mock_nodes_cls, api_client):
|
||||
"""Options get passed through."""
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.run.return_value = {"router1": "ok"}
|
||||
mock_nodes_cls.return_value = mock_instance
|
||||
|
||||
response = api_client.post("/run_commands", json={
|
||||
"action": "run",
|
||||
"nodes": "router1",
|
||||
"commands": ["show version"],
|
||||
"options": {"timeout": 30, "parallel": 5}
|
||||
})
|
||||
assert response.status_code == 200
|
||||
|
||||
@patch("connpy.api.nodes")
|
||||
def test_run_folder_nodes(self, mock_nodes_cls, api_client):
|
||||
"""Nodes with @ prefix are resolved as folders."""
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.run.return_value = {"server1@office": "ok"}
|
||||
mock_nodes_cls.return_value = mock_instance
|
||||
|
||||
response = api_client.post("/run_commands", json={
|
||||
"action": "run",
|
||||
"nodes": "@office",
|
||||
"commands": ["ls -la"]
|
||||
})
|
||||
assert response.status_code == 200
|
||||
|
||||
@patch("connpy.api.nodes")
|
||||
def test_run_list_nodes(self, mock_nodes_cls, api_client):
|
||||
"""List of nodes is resolved correctly."""
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.run.return_value = {"router1": "ok", "server1@office": "ok"}
|
||||
mock_nodes_cls.return_value = mock_instance
|
||||
|
||||
response = api_client.post("/run_commands", json={
|
||||
"action": "run",
|
||||
"nodes": ["router1", "server1@office"],
|
||||
"commands": ["show version"]
|
||||
})
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# /ask_ai endpoint
|
||||
# =========================================================================
|
||||
|
||||
class TestAskAI:
|
||||
@patch("connpy.api.myai")
|
||||
def test_ask_ai(self, mock_ai_cls, api_client):
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.ask.return_value = {"response": "AI says hello"}
|
||||
mock_ai_cls.return_value = mock_instance
|
||||
|
||||
response = api_client.post("/ask_ai", json={
|
||||
"input": "list my routers"
|
||||
})
|
||||
data = response.get_json()
|
||||
assert data is not None
|
||||
|
||||
@patch("connpy.api.myai")
|
||||
def test_ask_ai_with_dryrun(self, mock_ai_cls, api_client):
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.ask.return_value = {"response": "dry run"}
|
||||
mock_ai_cls.return_value = mock_instance
|
||||
|
||||
response = api_client.post("/ask_ai", json={
|
||||
"input": "test",
|
||||
"dryrun": True
|
||||
})
|
||||
assert response.status_code == 200
|
||||
|
||||
@patch("connpy.api.myai")
|
||||
def test_ask_ai_with_history(self, mock_ai_cls, api_client):
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.ask.return_value = {"response": "with history"}
|
||||
mock_ai_cls.return_value = mock_instance
|
||||
|
||||
response = api_client.post("/ask_ai", json={
|
||||
"input": "follow up",
|
||||
"chat_history": [
|
||||
{"role": "user", "content": "previous"},
|
||||
{"role": "assistant", "content": "answer"}
|
||||
]
|
||||
})
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# /confirm endpoint
|
||||
# =========================================================================
|
||||
|
||||
class TestConfirm:
|
||||
@patch("connpy.api.myai")
|
||||
def test_confirm(self, mock_ai_cls, api_client):
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.confirm.return_value = True
|
||||
mock_ai_cls.return_value = mock_instance
|
||||
|
||||
response = api_client.post("/confirm", json={
|
||||
"input": "yes"
|
||||
})
|
||||
assert response.status_code == 200
|
||||
@@ -0,0 +1,182 @@
|
||||
"""Tests for connpy.completion module."""
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
from connpy.completion import _getallnodes, _getallfolders, _getcwd, _get_plugins
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# _getallnodes tests
|
||||
# =========================================================================
|
||||
|
||||
class TestGetAllNodes:
|
||||
def test_flat_nodes(self):
|
||||
"""Nodes without folders."""
|
||||
config = {
|
||||
"connections": {
|
||||
"router1": {"type": "connection"},
|
||||
"router2": {"type": "connection"}
|
||||
}
|
||||
}
|
||||
nodes = _getallnodes(config)
|
||||
assert "router1" in nodes
|
||||
assert "router2" in nodes
|
||||
|
||||
def test_nested_nodes(self):
|
||||
"""Nodes in folders and subfolders have correct format."""
|
||||
config = {
|
||||
"connections": {
|
||||
"router1": {"type": "connection"},
|
||||
"office": {
|
||||
"type": "folder",
|
||||
"server1": {"type": "connection"},
|
||||
"datacenter": {
|
||||
"type": "subfolder",
|
||||
"db1": {"type": "connection"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes = _getallnodes(config)
|
||||
assert "router1" in nodes
|
||||
assert "server1@office" in nodes
|
||||
assert "db1@datacenter@office" in nodes
|
||||
|
||||
def test_empty_connections(self):
|
||||
config = {"connections": {}}
|
||||
nodes = _getallnodes(config)
|
||||
assert nodes == []
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# _getallfolders tests
|
||||
# =========================================================================
|
||||
|
||||
class TestGetAllFolders:
|
||||
def test_basic_folders(self):
|
||||
config = {
|
||||
"connections": {
|
||||
"office": {"type": "folder"},
|
||||
"home": {"type": "folder"}
|
||||
}
|
||||
}
|
||||
folders = _getallfolders(config)
|
||||
assert "@office" in folders
|
||||
assert "@home" in folders
|
||||
|
||||
def test_with_subfolders(self):
|
||||
config = {
|
||||
"connections": {
|
||||
"office": {
|
||||
"type": "folder",
|
||||
"datacenter": {"type": "subfolder"},
|
||||
"server1": {"type": "connection"}
|
||||
}
|
||||
}
|
||||
}
|
||||
folders = _getallfolders(config)
|
||||
assert "@office" in folders
|
||||
assert "@datacenter@office" in folders
|
||||
|
||||
def test_empty(self):
|
||||
config = {"connections": {}}
|
||||
folders = _getallfolders(config)
|
||||
assert folders == []
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# _getcwd tests
|
||||
# =========================================================================
|
||||
|
||||
class TestGetCwd:
|
||||
def test_current_dir(self, tmp_path, monkeypatch):
|
||||
"""Lists files in current directory."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "file1.txt").touch()
|
||||
(tmp_path / "file2.py").touch()
|
||||
subdir = tmp_path / "subdir"
|
||||
subdir.mkdir()
|
||||
|
||||
result = _getcwd(["run", "run"], "run")
|
||||
# Should list files
|
||||
assert any("file1.txt" in r for r in result)
|
||||
assert any("subdir/" in r for r in result)
|
||||
|
||||
def test_specific_path(self, tmp_path, monkeypatch):
|
||||
"""Lists files matching a partial path."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "script.yaml").touch()
|
||||
(tmp_path / "script2.yaml").touch()
|
||||
|
||||
result = _getcwd(["run", "script"], "run")
|
||||
assert any("script" in r for r in result)
|
||||
|
||||
def test_folder_only(self, tmp_path, monkeypatch):
|
||||
"""folderonly=True returns only directories."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "file.txt").touch()
|
||||
subdir = tmp_path / "mydir"
|
||||
subdir.mkdir()
|
||||
|
||||
result = _getcwd(["export", "export"], "export", folderonly=True)
|
||||
files_in_result = [r for r in result if "file.txt" in r]
|
||||
assert len(files_in_result) == 0
|
||||
dirs_in_result = [r for r in result if "mydir" in r]
|
||||
assert len(dirs_in_result) > 0
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# _get_plugins tests
|
||||
# =========================================================================
|
||||
|
||||
class TestGetPlugins:
|
||||
def test_get_plugins_disable(self, tmp_path):
|
||||
"""--disable returns enabled plugins."""
|
||||
plugin_dir = tmp_path / "plugins"
|
||||
plugin_dir.mkdir()
|
||||
(plugin_dir / "active.py").touch()
|
||||
(plugin_dir / "disabled.py.bkp").touch()
|
||||
|
||||
result = _get_plugins("--disable", str(tmp_path))
|
||||
assert "active" in result
|
||||
assert "disabled" not in result
|
||||
|
||||
def test_get_plugins_enable(self, tmp_path):
|
||||
"""--enable returns disabled plugins."""
|
||||
plugin_dir = tmp_path / "plugins"
|
||||
plugin_dir.mkdir()
|
||||
(plugin_dir / "active.py").touch()
|
||||
(plugin_dir / "disabled.py.bkp").touch()
|
||||
|
||||
result = _get_plugins("--enable", str(tmp_path))
|
||||
assert "disabled" in result
|
||||
assert "active" not in result
|
||||
|
||||
def test_get_plugins_del(self, tmp_path):
|
||||
"""--del returns all plugins."""
|
||||
plugin_dir = tmp_path / "plugins"
|
||||
plugin_dir.mkdir()
|
||||
(plugin_dir / "active.py").touch()
|
||||
(plugin_dir / "disabled.py.bkp").touch()
|
||||
|
||||
result = _get_plugins("--del", str(tmp_path))
|
||||
assert "active" in result
|
||||
assert "disabled" in result
|
||||
|
||||
def test_get_plugins_all(self, tmp_path):
|
||||
"""'all' returns dict with paths."""
|
||||
plugin_dir = tmp_path / "plugins"
|
||||
plugin_dir.mkdir()
|
||||
(plugin_dir / "myplugin.py").touch()
|
||||
|
||||
result = _get_plugins("all", str(tmp_path))
|
||||
assert isinstance(result, dict)
|
||||
assert "myplugin" in result
|
||||
|
||||
def test_get_plugins_empty_dir(self, tmp_path):
|
||||
"""Empty plugins directory returns empty list."""
|
||||
plugin_dir = tmp_path / "plugins"
|
||||
plugin_dir.mkdir()
|
||||
|
||||
result = _get_plugins("--disable", str(tmp_path))
|
||||
assert result == []
|
||||
@@ -0,0 +1,376 @@
|
||||
"""Tests for connpy.configfile module."""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import pytest
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class TestConfigfileInit:
|
||||
def test_creates_default_config(self, tmp_config_dir):
|
||||
"""Creates config.json with defaults when it doesn't exist."""
|
||||
config_file = tmp_config_dir / "config.json"
|
||||
config_file.unlink() # Remove existing
|
||||
key_file = tmp_config_dir / ".osk"
|
||||
|
||||
from connpy.configfile import configfile
|
||||
conf = configfile(conf=str(config_file), key=str(key_file))
|
||||
|
||||
assert config_file.exists()
|
||||
assert conf.config["case"] == False
|
||||
assert conf.config["idletime"] == 30
|
||||
assert "default" in conf.profiles
|
||||
|
||||
def test_creates_rsa_key(self, tmp_config_dir):
|
||||
"""Generates RSA key when it doesn't exist."""
|
||||
key_file = tmp_config_dir / ".osk"
|
||||
key_file.unlink() # Remove existing
|
||||
|
||||
from connpy.configfile import configfile
|
||||
conf = configfile(conf=str(tmp_config_dir / "config.json"), key=str(key_file))
|
||||
|
||||
assert key_file.exists()
|
||||
assert conf.privatekey is not None
|
||||
assert conf.publickey is not None
|
||||
|
||||
def test_loads_existing_config(self, config):
|
||||
"""Loads correctly from existing config."""
|
||||
assert config.config is not None
|
||||
assert config.connections is not None
|
||||
assert config.profiles is not None
|
||||
|
||||
def test_config_file_permissions(self, tmp_config_dir):
|
||||
"""Config is created with 0o600 permissions."""
|
||||
config_file = tmp_config_dir / "config.json"
|
||||
config_file.unlink()
|
||||
|
||||
from connpy.configfile import configfile
|
||||
configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
|
||||
|
||||
stat = os.stat(str(config_file))
|
||||
assert oct(stat.st_mode & 0o777) == oct(0o600)
|
||||
|
||||
def test_custom_paths(self, tmp_path):
|
||||
"""Accepts custom paths for conf and key."""
|
||||
config_dir = tmp_path / "custom"
|
||||
config_dir.mkdir()
|
||||
(config_dir / "plugins").mkdir()
|
||||
|
||||
# Write .folder for the config dir
|
||||
dot_folder = tmp_path / ".config" / "conn"
|
||||
dot_folder.mkdir(parents=True, exist_ok=True)
|
||||
(dot_folder / ".folder").write_text(str(config_dir))
|
||||
(dot_folder / "plugins").mkdir(exist_ok=True)
|
||||
|
||||
conf_path = str(config_dir / "my_config.json")
|
||||
key_path = str(config_dir / "my_key")
|
||||
|
||||
from connpy.configfile import configfile
|
||||
conf = configfile(conf=conf_path, key=key_path)
|
||||
|
||||
assert conf.file == conf_path
|
||||
assert conf.key == key_path
|
||||
|
||||
|
||||
class TestEncryption:
|
||||
def test_encrypt_password(self, config):
|
||||
"""Encrypts and produces b'...' format."""
|
||||
encrypted = config.encrypt("mysecret")
|
||||
assert encrypted.startswith("b'") or encrypted.startswith('b"')
|
||||
|
||||
def test_encrypt_decrypt_roundtrip(self, config):
|
||||
"""Encrypt then decrypt returns original."""
|
||||
from Crypto.PublicKey import RSA
|
||||
from Crypto.Cipher import PKCS1_OAEP
|
||||
import ast
|
||||
|
||||
original = "super_secret_password"
|
||||
encrypted = config.encrypt(original)
|
||||
|
||||
# Decrypt
|
||||
with open(config.key) as f:
|
||||
key = RSA.import_key(f.read())
|
||||
decryptor = PKCS1_OAEP.new(key)
|
||||
decrypted = decryptor.decrypt(ast.literal_eval(encrypted)).decode("utf-8")
|
||||
assert decrypted == original
|
||||
|
||||
|
||||
class TestExplodeUnique:
|
||||
def test_simple_node(self, config):
|
||||
result = config._explode_unique("router1")
|
||||
assert result == {"id": "router1"}
|
||||
|
||||
def test_node_with_folder(self, config):
|
||||
result = config._explode_unique("r1@office")
|
||||
assert result == {"id": "r1", "folder": "office"}
|
||||
|
||||
def test_node_with_subfolder(self, config):
|
||||
result = config._explode_unique("r1@dc@office")
|
||||
assert result == {"id": "r1", "folder": "office", "subfolder": "dc"}
|
||||
|
||||
def test_folder_only(self, config):
|
||||
result = config._explode_unique("@office")
|
||||
assert result == {"folder": "office"}
|
||||
|
||||
def test_subfolder_only(self, config):
|
||||
result = config._explode_unique("@dc@office")
|
||||
assert result == {"folder": "office", "subfolder": "dc"}
|
||||
|
||||
def test_too_deep(self, config):
|
||||
result = config._explode_unique("a@b@c@d")
|
||||
assert result == False
|
||||
|
||||
def test_empty_folder(self, config):
|
||||
result = config._explode_unique("a@")
|
||||
assert result == False
|
||||
|
||||
def test_empty_subfolder(self, config):
|
||||
result = config._explode_unique("a@@office")
|
||||
assert result == False
|
||||
|
||||
|
||||
class TestCRUDNodes:
|
||||
def test_add_node_root(self, config):
|
||||
config._connections_add(
|
||||
id="router1", host="10.0.0.1", protocol="ssh",
|
||||
port="22", user="admin", password="pass", options="",
|
||||
logs="", tags="", jumphost=""
|
||||
)
|
||||
assert "router1" in config.connections
|
||||
assert config.connections["router1"]["host"] == "10.0.0.1"
|
||||
|
||||
def test_add_node_folder(self, config):
|
||||
config._folder_add(folder="office")
|
||||
config._connections_add(
|
||||
id="server1", folder="office", host="10.0.1.1",
|
||||
protocol="ssh", port="", user="root", password="pass",
|
||||
options="", logs="", tags="", jumphost=""
|
||||
)
|
||||
assert "server1" in config.connections["office"]
|
||||
|
||||
def test_add_node_subfolder(self, config):
|
||||
config._folder_add(folder="office")
|
||||
config._folder_add(folder="office", subfolder="dc")
|
||||
config._connections_add(
|
||||
id="db1", folder="office", subfolder="dc", host="10.0.2.1",
|
||||
protocol="ssh", port="", user="dbadmin", password="pass",
|
||||
options="", logs="", tags="", jumphost=""
|
||||
)
|
||||
assert "db1" in config.connections["office"]["dc"]
|
||||
|
||||
def test_del_node_root(self, config):
|
||||
config._connections_add(
|
||||
id="router1", host="10.0.0.1", protocol="ssh",
|
||||
port="", user="", password="", options="",
|
||||
logs="", tags="", jumphost=""
|
||||
)
|
||||
config._connections_del(id="router1")
|
||||
assert "router1" not in config.connections
|
||||
|
||||
def test_del_node_folder(self, config):
|
||||
config._folder_add(folder="office")
|
||||
config._connections_add(
|
||||
id="server1", folder="office", host="10.0.1.1",
|
||||
protocol="ssh", port="", user="", password="",
|
||||
options="", logs="", tags="", jumphost=""
|
||||
)
|
||||
config._connections_del(id="server1", folder="office")
|
||||
assert "server1" not in config.connections["office"]
|
||||
|
||||
def test_add_folder(self, config):
|
||||
config._folder_add(folder="office")
|
||||
assert "office" in config.connections
|
||||
assert config.connections["office"]["type"] == "folder"
|
||||
|
||||
def test_add_subfolder(self, config):
|
||||
config._folder_add(folder="office")
|
||||
config._folder_add(folder="office", subfolder="dc")
|
||||
assert "dc" in config.connections["office"]
|
||||
assert config.connections["office"]["dc"]["type"] == "subfolder"
|
||||
|
||||
def test_del_folder(self, config):
|
||||
config._folder_add(folder="office")
|
||||
config._folder_del(folder="office")
|
||||
assert "office" not in config.connections
|
||||
|
||||
def test_del_subfolder(self, config):
|
||||
config._folder_add(folder="office")
|
||||
config._folder_add(folder="office", subfolder="dc")
|
||||
config._folder_del(folder="office", subfolder="dc")
|
||||
assert "dc" not in config.connections["office"]
|
||||
|
||||
|
||||
class TestCRUDProfiles:
|
||||
def test_add_profile(self, config):
|
||||
config._profiles_add(
|
||||
id="myprofile", host="", protocol="telnet",
|
||||
port="23", user="user1", password="pass1",
|
||||
options="", logs="", tags="", jumphost=""
|
||||
)
|
||||
assert "myprofile" in config.profiles
|
||||
assert config.profiles["myprofile"]["protocol"] == "telnet"
|
||||
|
||||
def test_del_profile(self, config):
|
||||
config._profiles_add(
|
||||
id="temp", host="", protocol="ssh", port="",
|
||||
user="", password="", options="", logs="", tags="", jumphost=""
|
||||
)
|
||||
config._profiles_del(id="temp")
|
||||
assert "temp" not in config.profiles
|
||||
|
||||
def test_default_profile_exists(self, config):
|
||||
assert "default" in config.profiles
|
||||
|
||||
|
||||
class TestGetItem:
|
||||
def test_getitem_node(self, populated_config):
|
||||
node = populated_config.getitem("router1")
|
||||
assert node["host"] == "10.0.0.1"
|
||||
assert "type" not in node # type is stripped
|
||||
|
||||
def test_getitem_folder(self, populated_config):
|
||||
nodes = populated_config.getitem("@office")
|
||||
# Should contain server1@office but NOT datacenter (subfolder)
|
||||
assert "server1@office" in nodes
|
||||
assert all("type" not in v for v in nodes.values())
|
||||
|
||||
def test_getitem_subfolder(self, populated_config):
|
||||
nodes = populated_config.getitem("@datacenter@office")
|
||||
assert "db1@datacenter@office" in nodes
|
||||
|
||||
def test_getitem_node_in_folder(self, populated_config):
|
||||
node = populated_config.getitem("server1@office")
|
||||
assert node["host"] == "10.0.1.1"
|
||||
|
||||
def test_getitem_node_in_subfolder(self, populated_config):
|
||||
node = populated_config.getitem("db1@datacenter@office")
|
||||
assert node["host"] == "10.0.2.1"
|
||||
|
||||
def test_getitem_with_profile_extraction(self, tmp_config_dir):
|
||||
"""extract=True resolves @profile references."""
|
||||
config_file = tmp_config_dir / "config.json"
|
||||
data = {
|
||||
"config": {"case": False, "idletime": 30, "fzf": False},
|
||||
"connections": {
|
||||
"router1": {
|
||||
"host": "10.0.0.1", "protocol": "ssh", "port": "",
|
||||
"user": "@office-user", "password": "@office-user",
|
||||
"options": "", "logs": "", "tags": "", "jumphost": "",
|
||||
"type": "connection"
|
||||
}
|
||||
},
|
||||
"profiles": {
|
||||
"default": {"host": "", "protocol": "ssh", "port": "",
|
||||
"user": "", "password": "", "options": "",
|
||||
"logs": "", "tags": "", "jumphost": ""},
|
||||
"office-user": {"host": "", "protocol": "ssh", "port": "",
|
||||
"user": "officeadmin", "password": "officepass",
|
||||
"options": "", "logs": "", "tags": "", "jumphost": ""}
|
||||
}
|
||||
}
|
||||
config_file.write_text(json.dumps(data, indent=4))
|
||||
|
||||
from connpy.configfile import configfile
|
||||
conf = configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
|
||||
|
||||
node = conf.getitem("router1", extract=True)
|
||||
assert node["user"] == "officeadmin"
|
||||
assert node["password"] == "officepass"
|
||||
|
||||
def test_getitems_multiple(self, populated_config):
|
||||
nodes = populated_config.getitems(["router1", "server1@office"])
|
||||
assert "router1" in nodes
|
||||
assert "server1@office" in nodes
|
||||
|
||||
def test_getitems_folder(self, populated_config):
|
||||
nodes = populated_config.getitems(["@office"])
|
||||
assert "server1@office" in nodes
|
||||
|
||||
|
||||
class TestGetAll:
|
||||
def test_getallnodes_no_filter(self, populated_config):
|
||||
nodes = populated_config._getallnodes()
|
||||
assert "router1" in nodes
|
||||
assert "server1@office" in nodes
|
||||
assert "db1@datacenter@office" in nodes
|
||||
|
||||
def test_getallnodes_string_filter(self, populated_config):
|
||||
nodes = populated_config._getallnodes("router.*")
|
||||
assert "router1" in nodes
|
||||
assert "server1@office" not in nodes
|
||||
|
||||
def test_getallnodes_list_filter(self, populated_config):
|
||||
nodes = populated_config._getallnodes(["router.*", "db.*"])
|
||||
assert "router1" in nodes
|
||||
assert "db1@datacenter@office" in nodes
|
||||
assert "server1@office" not in nodes
|
||||
|
||||
def test_getallnodes_filter_invalid_type(self, populated_config):
|
||||
with pytest.raises(ValueError):
|
||||
populated_config._getallnodes(123)
|
||||
|
||||
def test_getallfolders(self, populated_config):
|
||||
folders = populated_config._getallfolders()
|
||||
assert "@office" in folders
|
||||
assert "@datacenter@office" in folders
|
||||
|
||||
def test_getallnodesfull(self, populated_config):
|
||||
nodes = populated_config._getallnodesfull()
|
||||
assert "router1" in nodes
|
||||
assert nodes["router1"]["host"] == "10.0.0.1"
|
||||
|
||||
def test_getallnodesfull_with_filter(self, populated_config):
|
||||
nodes = populated_config._getallnodesfull("router.*")
|
||||
assert "router1" in nodes
|
||||
assert "server1@office" not in nodes
|
||||
|
||||
def test_profileused(self, tmp_config_dir):
|
||||
"""Detects nodes using a specific profile."""
|
||||
config_file = tmp_config_dir / "config.json"
|
||||
data = {
|
||||
"config": {"case": False, "idletime": 30, "fzf": False},
|
||||
"connections": {
|
||||
"router1": {
|
||||
"host": "10.0.0.1", "protocol": "ssh", "port": "",
|
||||
"user": "@myprofile", "password": "pass",
|
||||
"options": "", "logs": "", "tags": "", "jumphost": "",
|
||||
"type": "connection"
|
||||
},
|
||||
"router2": {
|
||||
"host": "10.0.0.2", "protocol": "ssh", "port": "",
|
||||
"user": "admin", "password": "pass",
|
||||
"options": "", "logs": "", "tags": "", "jumphost": "",
|
||||
"type": "connection"
|
||||
}
|
||||
},
|
||||
"profiles": {
|
||||
"default": {"host": "", "protocol": "ssh", "port": "",
|
||||
"user": "", "password": "", "options": "",
|
||||
"logs": "", "tags": "", "jumphost": ""},
|
||||
"myprofile": {"host": "", "protocol": "ssh", "port": "",
|
||||
"user": "profuser", "password": "profpass",
|
||||
"options": "", "logs": "", "tags": "", "jumphost": ""}
|
||||
}
|
||||
}
|
||||
config_file.write_text(json.dumps(data, indent=4))
|
||||
from connpy.configfile import configfile
|
||||
conf = configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
|
||||
|
||||
used = conf._profileused("myprofile")
|
||||
assert "router1" in used
|
||||
assert "router2" not in used
|
||||
|
||||
def test_saveconfig(self, config):
|
||||
"""Save and reload correctly."""
|
||||
config._connections_add(
|
||||
id="test_node", host="1.2.3.4", protocol="ssh",
|
||||
port="", user="", password="", options="",
|
||||
logs="", tags="", jumphost=""
|
||||
)
|
||||
result = config._saveconfig(config.file)
|
||||
assert result == 0
|
||||
|
||||
# Reload and verify
|
||||
from connpy.configfile import configfile
|
||||
reloaded = configfile(conf=config.file, key=config.key)
|
||||
assert "test_node" in reloaded.connections
|
||||
@@ -0,0 +1,419 @@
|
||||
"""Tests for connpy.core module — node and nodes classes."""
|
||||
import json
|
||||
import os
|
||||
import io
|
||||
import re
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, PropertyMock
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# node.__init__ tests
|
||||
# =========================================================================
|
||||
|
||||
class TestNodeInit:
|
||||
def test_basic_init(self):
|
||||
"""Creates node with basic attributes."""
|
||||
from connpy.core import node
|
||||
n = node("router1", "10.0.0.1", user="admin", password="pass1", protocol="ssh")
|
||||
assert n.unique == "router1"
|
||||
assert n.host == "10.0.0.1"
|
||||
assert n.user == "admin"
|
||||
assert n.protocol == "ssh"
|
||||
assert n.password == ["pass1"]
|
||||
|
||||
def test_default_protocol(self):
|
||||
"""Default protocol is ssh."""
|
||||
from connpy.core import node
|
||||
n = node("router1", "10.0.0.1")
|
||||
assert n.protocol == "ssh"
|
||||
|
||||
def test_password_as_list_of_profiles(self, populated_config):
|
||||
"""Password list with @profile references resolves correctly."""
|
||||
from connpy.core import node
|
||||
n = node("router1", "10.0.0.1", password=["@office-user"],
|
||||
config=populated_config)
|
||||
assert n.password == ["officepass"]
|
||||
|
||||
def test_password_plain_string(self):
|
||||
"""Plain string password is wrapped in a list."""
|
||||
from connpy.core import node
|
||||
n = node("router1", "10.0.0.1", password="mypass")
|
||||
assert n.password == ["mypass"]
|
||||
|
||||
def test_node_with_profile(self, populated_config):
|
||||
"""Resolves @profile references for user."""
|
||||
from connpy.core import node
|
||||
n = node("test1", "10.0.0.1", user="@office-user", password="plain",
|
||||
config=populated_config)
|
||||
assert n.user == "officeadmin"
|
||||
|
||||
def test_node_tags(self):
|
||||
"""Tags are stored correctly."""
|
||||
from connpy.core import node
|
||||
tags = {"os": "cisco_ios", "prompt": r"Router#"}
|
||||
n = node("router1", "10.0.0.1", tags=tags)
|
||||
assert n.tags["os"] == "cisco_ios"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Command generation tests
|
||||
# =========================================================================
|
||||
|
||||
class TestCommandGeneration:
|
||||
def _make_node(self, **kwargs):
|
||||
from connpy.core import node
|
||||
defaults = {
|
||||
"unique": "test", "host": "10.0.0.1", "protocol": "ssh",
|
||||
"user": "admin", "password": "", "port": "", "options": "",
|
||||
"jumphost": "", "tags": "", "logs": ""
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return node(defaults.pop("unique"), defaults.pop("host"), **defaults)
|
||||
|
||||
def test_ssh_cmd_basic(self):
|
||||
n = self._make_node()
|
||||
cmd = n._get_cmd()
|
||||
assert "ssh" in cmd
|
||||
assert "admin@10.0.0.1" in cmd
|
||||
|
||||
def test_ssh_cmd_port(self):
|
||||
n = self._make_node(port="2222")
|
||||
cmd = n._get_cmd()
|
||||
assert "-p 2222" in cmd
|
||||
|
||||
def test_ssh_cmd_options(self):
|
||||
n = self._make_node(options="-o StrictHostKeyChecking=no")
|
||||
cmd = n._get_cmd()
|
||||
assert "-o StrictHostKeyChecking=no" in cmd
|
||||
|
||||
def test_sftp_cmd_port(self):
|
||||
n = self._make_node(protocol="sftp", port="2222")
|
||||
cmd = n._get_cmd()
|
||||
assert "-P 2222" in cmd # SFTP uses uppercase P
|
||||
|
||||
def test_telnet_cmd(self):
|
||||
n = self._make_node(protocol="telnet", port="23")
|
||||
cmd = n._get_cmd()
|
||||
assert "telnet 10.0.0.1" in cmd
|
||||
assert "23" in cmd
|
||||
|
||||
def test_kubectl_cmd(self):
|
||||
n = self._make_node(protocol="kubectl", host="my-pod", tags={"kube_command": "/bin/sh"})
|
||||
cmd = n._get_cmd()
|
||||
assert "kubectl exec" in cmd
|
||||
assert "my-pod" in cmd
|
||||
assert "/bin/sh" in cmd
|
||||
|
||||
def test_kubectl_cmd_default_command(self):
|
||||
n = self._make_node(protocol="kubectl", host="my-pod")
|
||||
cmd = n._get_cmd()
|
||||
assert "/bin/bash" in cmd
|
||||
|
||||
def test_docker_cmd(self):
|
||||
n = self._make_node(protocol="docker", host="my-container",
|
||||
tags={"docker_command": "/bin/sh"})
|
||||
cmd = n._get_cmd()
|
||||
assert "docker" in cmd
|
||||
assert "my-container" in cmd
|
||||
assert "/bin/sh" in cmd
|
||||
|
||||
def test_invalid_protocol_raises(self):
|
||||
n = self._make_node(protocol="invalid_proto")
|
||||
with pytest.raises(ValueError, match="Invalid protocol"):
|
||||
n._get_cmd()
|
||||
|
||||
def test_ssh_cmd_no_user(self):
|
||||
n = self._make_node(user="")
|
||||
cmd = n._get_cmd()
|
||||
assert "10.0.0.1" in cmd
|
||||
assert "@" not in cmd # No user@ prefix
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Password decryption tests
|
||||
# =========================================================================
|
||||
|
||||
class TestPasswordDecryption:
|
||||
def test_passtx_plaintext(self, config):
|
||||
"""Plaintext passwords pass through unchanged."""
|
||||
from connpy.core import node
|
||||
n = node("test", "10.0.0.1", password="plainpass", config=config)
|
||||
result = n._passtx(["plainpass"])
|
||||
assert result == ["plainpass"]
|
||||
|
||||
def test_passtx_encrypted(self, config):
|
||||
"""Encrypted passwords get decrypted."""
|
||||
from connpy.core import node
|
||||
encrypted = config.encrypt("mysecret")
|
||||
n = node("test", "10.0.0.1", password=encrypted, config=config)
|
||||
result = n._passtx([encrypted])
|
||||
assert result == ["mysecret"]
|
||||
|
||||
def test_passtx_missing_key_raises(self):
|
||||
"""Missing key file raises ValueError."""
|
||||
from connpy.core import node
|
||||
n = node("test", "10.0.0.1", password="pass")
|
||||
# A password formatted as encrypted but no valid key
|
||||
with pytest.raises((ValueError, Exception)):
|
||||
n._passtx(["""b'corrupted_encrypted_data'"""], keyfile="/nonexistent")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Log handling tests
|
||||
# =========================================================================
|
||||
|
||||
class TestLogHandling:
|
||||
def test_logfile_variable_substitution(self):
|
||||
from connpy.core import node
|
||||
n = node("router1", "10.0.0.1", user="admin", protocol="ssh", port="22",
|
||||
logs="/logs/${unique}_${host}_${user}")
|
||||
result = n._logfile()
|
||||
assert result == "/logs/router1_10.0.0.1_admin"
|
||||
|
||||
def test_logfile_date_substitution(self):
|
||||
from connpy.core import node
|
||||
import datetime
|
||||
n = node("router1", "10.0.0.1", logs="/logs/${date '%Y'}")
|
||||
result = n._logfile()
|
||||
assert datetime.datetime.now().strftime("%Y") in result
|
||||
|
||||
def test_logclean_removes_ansi(self):
|
||||
from connpy.core import node
|
||||
n = node("test", "10.0.0.1")
|
||||
dirty = "\x1B[32mgreen text\x1B[0m"
|
||||
clean = n._logclean(dirty, var=True)
|
||||
assert "\x1B" not in clean
|
||||
assert "green text" in clean
|
||||
|
||||
def test_logclean_removes_backspaces(self):
|
||||
from connpy.core import node
|
||||
n = node("test", "10.0.0.1")
|
||||
dirty = "type\bo"
|
||||
clean = n._logclean(dirty, var=True)
|
||||
assert "\b" not in clean
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# run() and test() with mock pexpect
|
||||
# =========================================================================
|
||||
|
||||
class TestNodeRun:
|
||||
def _make_connected_node(self, mock_pexpect_obj, **kwargs):
|
||||
"""Create a node and mock its _connect to succeed."""
|
||||
from connpy.core import node
|
||||
defaults = {
|
||||
"unique": "router1", "host": "10.0.0.1",
|
||||
"protocol": "ssh", "user": "admin", "password": ""
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
n = node(defaults.pop("unique"), defaults.pop("host"), **defaults)
|
||||
return n
|
||||
|
||||
def test_run_returns_output(self, mock_pexpect):
|
||||
"""run() returns string output."""
|
||||
child = mock_pexpect["child"]
|
||||
pexp = mock_pexpect["pexpect"]
|
||||
|
||||
# Simulate: connect succeeds, command runs, prompt found
|
||||
child.expect.return_value = 9 # prompt index for ssh
|
||||
child.logfile_read = None
|
||||
|
||||
from connpy.core import node
|
||||
n = node("router1", "10.0.0.1", user="admin", password="")
|
||||
|
||||
# Mock _connect to return True and set up child
|
||||
with patch.object(n, '_connect', return_value=True):
|
||||
n.child = child
|
||||
log_buffer = io.BytesIO(b"show version\nRouter v1.0\nrouter#")
|
||||
n.mylog = log_buffer
|
||||
child.logfile_read = log_buffer
|
||||
|
||||
with patch.object(n, '_logclean', return_value="Router v1.0"):
|
||||
output = n.run(["show version"])
|
||||
|
||||
assert n.status == 0
|
||||
assert output == "Router v1.0"
|
||||
|
||||
def test_run_status_1_on_failure(self, mock_pexpect):
|
||||
"""Status 1 when connection fails."""
|
||||
from connpy.core import node
|
||||
n = node("router1", "10.0.0.1", user="admin", password="")
|
||||
|
||||
with patch.object(n, '_connect', return_value="Connection failed code: 1\nrefused"):
|
||||
output = n.run(["show version"])
|
||||
|
||||
assert n.status == 1
|
||||
assert "refused" in output
|
||||
|
||||
def test_run_with_variables(self, mock_pexpect):
|
||||
"""Variables get substituted in commands."""
|
||||
child = mock_pexpect["child"]
|
||||
child.expect.return_value = 9
|
||||
|
||||
from connpy.core import node
|
||||
n = node("router1", "10.0.0.1", user="admin", password="")
|
||||
|
||||
sent_commands = []
|
||||
child.sendline.side_effect = lambda cmd: sent_commands.append(cmd)
|
||||
|
||||
with patch.object(n, '_connect', return_value=True):
|
||||
n.child = child
|
||||
n.mylog = io.BytesIO(b"output")
|
||||
with patch.object(n, '_logclean', return_value="output"):
|
||||
n.run(["show ip route {subnet}"], vars={"subnet": "10.0.0.0/24"})
|
||||
|
||||
assert "show ip route 10.0.0.0/24" in sent_commands
|
||||
|
||||
def test_run_saves_to_folder(self, mock_pexpect, tmp_path):
|
||||
"""folder param saves log file."""
|
||||
child = mock_pexpect["child"]
|
||||
child.expect.return_value = 9
|
||||
|
||||
from connpy.core import node
|
||||
n = node("router1", "10.0.0.1", user="admin", password="")
|
||||
|
||||
with patch.object(n, '_connect', return_value=True):
|
||||
n.child = child
|
||||
n.mylog = io.BytesIO(b"log output")
|
||||
with patch.object(n, '_logclean', return_value="log output"):
|
||||
n.run(["show version"], folder=str(tmp_path))
|
||||
|
||||
log_files = list(tmp_path.glob("router1_*.txt"))
|
||||
assert len(log_files) == 1
|
||||
assert "log output" in log_files[0].read_text()
|
||||
|
||||
|
||||
class TestNodeTest:
|
||||
def test_test_returns_dict(self, mock_pexpect):
|
||||
"""test() returns dict of results."""
|
||||
child = mock_pexpect["child"]
|
||||
child.expect.return_value = 0 # prompt found (index 0 in test expects)
|
||||
|
||||
from connpy.core import node
|
||||
n = node("router1", "10.0.0.1", user="admin", password="")
|
||||
|
||||
with patch.object(n, '_connect', return_value=True):
|
||||
n.child = child
|
||||
n.mylog = io.BytesIO(b"1.1.1.1 is up")
|
||||
with patch.object(n, '_logclean', return_value="1.1.1.1 is up"):
|
||||
result = n.test(["ping 1.1.1.1"], "1.1.1.1")
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result.get("1.1.1.1") == True
|
||||
|
||||
def test_test_expected_not_found(self, mock_pexpect):
|
||||
"""Expected text not found returns False."""
|
||||
child = mock_pexpect["child"]
|
||||
child.expect.return_value = 0
|
||||
|
||||
from connpy.core import node
|
||||
n = node("router1", "10.0.0.1", user="admin", password="")
|
||||
|
||||
with patch.object(n, '_connect', return_value=True):
|
||||
n.child = child
|
||||
n.mylog = io.BytesIO(b"some other output")
|
||||
with patch.object(n, '_logclean', return_value="some other output"):
|
||||
result = n.test(["ping 1.1.1.1"], "1.1.1.1")
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result.get("1.1.1.1") == False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# nodes (parallel) tests
|
||||
# =========================================================================
|
||||
|
||||
class TestNodes:
|
||||
def test_nodes_init(self):
|
||||
"""Creates list of node objects."""
|
||||
from connpy.core import nodes
|
||||
nodes_dict = {
|
||||
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
|
||||
"r2": {"host": "10.0.0.2", "user": "admin", "password": ""}
|
||||
}
|
||||
mynodes = nodes(nodes_dict)
|
||||
assert len(mynodes.nodelist) == 2
|
||||
assert hasattr(mynodes, "r1")
|
||||
assert hasattr(mynodes, "r2")
|
||||
|
||||
def test_nodes_run_parallel(self):
|
||||
"""run() executes on all nodes and returns dict."""
|
||||
from connpy.core import nodes
|
||||
|
||||
nodes_dict = {
|
||||
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
|
||||
"r2": {"host": "10.0.0.2", "user": "admin", "password": ""}
|
||||
}
|
||||
mynodes = nodes(nodes_dict)
|
||||
|
||||
# Mock run on each node — must set output AND status on the node
|
||||
for n in mynodes.nodelist:
|
||||
original_node = n # capture by value
|
||||
def make_mock(node_ref):
|
||||
def mock_run(commands, **kwargs):
|
||||
node_ref.output = f"output from {node_ref.unique}"
|
||||
node_ref.status = 0
|
||||
return mock_run
|
||||
n.run = make_mock(n)
|
||||
|
||||
result = mynodes.run(["show version"])
|
||||
assert "r1" in result
|
||||
assert "r2" in result
|
||||
|
||||
def test_nodes_splitlist(self):
|
||||
"""_splitlist divides list correctly."""
|
||||
from connpy.core import nodes
|
||||
mynodes = nodes({"r1": {"host": "1.1.1.1", "user": "", "password": ""}})
|
||||
chunks = list(mynodes._splitlist([1, 2, 3, 4, 5], 2))
|
||||
assert chunks == [[1, 2], [3, 4], [5]]
|
||||
|
||||
def test_nodes_run_with_vars(self):
|
||||
"""Variables per node and __global__ work."""
|
||||
from connpy.core import nodes
|
||||
|
||||
nodes_dict = {
|
||||
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
|
||||
}
|
||||
mynodes = nodes(nodes_dict)
|
||||
|
||||
captured_vars = {}
|
||||
|
||||
def mock_run(commands, vars=None, **kwargs):
|
||||
captured_vars.update(vars or {})
|
||||
mynodes.r1.output = "ok"
|
||||
mynodes.r1.status = 0
|
||||
|
||||
mynodes.r1.run = mock_run
|
||||
|
||||
variables = {
|
||||
"__global__": {"mask": "255.255.255.0"},
|
||||
"r1": {"ip": "10.0.0.1"}
|
||||
}
|
||||
mynodes.run(["show ip"], vars=variables)
|
||||
assert captured_vars.get("mask") == "255.255.255.0"
|
||||
assert captured_vars.get("ip") == "10.0.0.1"
|
||||
|
||||
def test_nodes_on_complete_callback(self):
|
||||
"""on_complete callback fires per node."""
|
||||
from connpy.core import nodes
|
||||
|
||||
nodes_dict = {
|
||||
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
|
||||
}
|
||||
mynodes = nodes(nodes_dict)
|
||||
|
||||
completed = []
|
||||
|
||||
def mock_run(commands, **kwargs):
|
||||
mynodes.r1.output = "done"
|
||||
mynodes.r1.status = 0
|
||||
|
||||
mynodes.r1.run = mock_run
|
||||
|
||||
def on_done(unique, output, status):
|
||||
completed.append(unique)
|
||||
|
||||
mynodes.run(["show version"], on_complete=on_done)
|
||||
assert "r1" in completed
|
||||
@@ -0,0 +1,216 @@
|
||||
"""Tests for connpy.hooks module — MethodHook and ClassHook."""
|
||||
import pytest
|
||||
from connpy.hooks import MethodHook, ClassHook
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# MethodHook Tests
|
||||
# =========================================================================
|
||||
|
||||
class TestMethodHook:
|
||||
def test_basic_call(self):
|
||||
"""Decorated function executes normally."""
|
||||
@MethodHook
|
||||
def add(a, b):
|
||||
return a + b
|
||||
assert add(2, 3) == 5
|
||||
|
||||
def test_pre_hook_modifies_args(self):
|
||||
"""Pre-hook can modify arguments before execution."""
|
||||
@MethodHook
|
||||
def greet(name):
|
||||
return f"Hello {name}"
|
||||
|
||||
def uppercase_hook(name):
|
||||
return (name.upper(),), {}
|
||||
|
||||
greet.register_pre_hook(uppercase_hook)
|
||||
assert greet("world") == "Hello WORLD"
|
||||
|
||||
def test_post_hook_modifies_result(self):
|
||||
"""Post-hook can modify the return value."""
|
||||
@MethodHook
|
||||
def compute(x):
|
||||
return x * 2
|
||||
|
||||
def double_result(*args, **kwargs):
|
||||
return kwargs["result"] * 2
|
||||
|
||||
compute.register_post_hook(double_result)
|
||||
assert compute(5) == 20 # 5*2=10, then 10*2=20
|
||||
|
||||
def test_multiple_pre_hooks_order(self):
|
||||
"""Pre-hooks execute in registration order."""
|
||||
calls = []
|
||||
|
||||
@MethodHook
|
||||
def func(x):
|
||||
return x
|
||||
|
||||
def hook1(x):
|
||||
calls.append("hook1")
|
||||
return (x,), {}
|
||||
|
||||
def hook2(x):
|
||||
calls.append("hook2")
|
||||
return (x,), {}
|
||||
|
||||
func.register_pre_hook(hook1)
|
||||
func.register_pre_hook(hook2)
|
||||
func(1)
|
||||
assert calls == ["hook1", "hook2"]
|
||||
|
||||
def test_multiple_post_hooks_order(self):
|
||||
"""Post-hooks execute in registration order."""
|
||||
calls = []
|
||||
|
||||
@MethodHook
|
||||
def func(x):
|
||||
return x
|
||||
|
||||
def hook1(*args, **kwargs):
|
||||
calls.append("hook1")
|
||||
return kwargs["result"]
|
||||
|
||||
def hook2(*args, **kwargs):
|
||||
calls.append("hook2")
|
||||
return kwargs["result"]
|
||||
|
||||
func.register_post_hook(hook1)
|
||||
func.register_post_hook(hook2)
|
||||
func(1)
|
||||
assert calls == ["hook1", "hook2"]
|
||||
|
||||
def test_pre_hook_exception_continues(self, capsys):
|
||||
"""If a pre-hook raises, the function still executes."""
|
||||
@MethodHook
|
||||
def func(x):
|
||||
return x + 1
|
||||
|
||||
def bad_hook(x):
|
||||
raise RuntimeError("broken hook")
|
||||
|
||||
func.register_pre_hook(bad_hook)
|
||||
# Should not raise — the hook error is printed but execution continues
|
||||
result = func(5)
|
||||
assert result == 6
|
||||
|
||||
def test_post_hook_exception_continues(self, capsys):
|
||||
"""If a post-hook raises, the result is still returned."""
|
||||
@MethodHook
|
||||
def func(x):
|
||||
return x + 1
|
||||
|
||||
def bad_hook(*args, **kwargs):
|
||||
raise RuntimeError("broken post hook")
|
||||
|
||||
func.register_post_hook(bad_hook)
|
||||
result = func(5)
|
||||
assert result == 6
|
||||
|
||||
def test_method_hook_as_instance_method(self):
|
||||
"""MethodHook works as a descriptor on a class."""
|
||||
class MyClass:
|
||||
@MethodHook
|
||||
def double(self, x):
|
||||
return x * 2
|
||||
|
||||
obj = MyClass()
|
||||
assert obj.double(5) == 10
|
||||
|
||||
def test_method_hook_instance_hook_registration(self):
|
||||
"""Can register hooks via instance method access."""
|
||||
class MyClass:
|
||||
@MethodHook
|
||||
def process(self, x):
|
||||
return x
|
||||
|
||||
def add_ten(*args, **kwargs):
|
||||
return kwargs["result"] + 10
|
||||
|
||||
obj = MyClass()
|
||||
obj.process.register_post_hook(add_ten)
|
||||
assert obj.process(5) == 15
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# ClassHook Tests
|
||||
# =========================================================================
|
||||
|
||||
class TestClassHook:
|
||||
def test_creates_instance(self):
|
||||
"""ClassHook still creates instances normally."""
|
||||
@ClassHook
|
||||
class MyClass:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
obj = MyClass(42)
|
||||
assert obj.value == 42
|
||||
|
||||
def test_modify_future_instances(self):
|
||||
"""modify() affects all future instances."""
|
||||
@ClassHook
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def set_x_to_99(instance):
|
||||
instance.x = 99
|
||||
|
||||
MyClass.modify(set_x_to_99)
|
||||
obj = MyClass()
|
||||
assert obj.x == 99
|
||||
|
||||
def test_modify_does_not_affect_past(self):
|
||||
"""modify() does not affect already-created instances."""
|
||||
@ClassHook
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
old_obj = MyClass()
|
||||
|
||||
def set_x_to_99(instance):
|
||||
instance.x = 99
|
||||
|
||||
MyClass.modify(set_x_to_99)
|
||||
assert old_obj.x == 1 # Not affected
|
||||
assert MyClass().x == 99 # New instance IS affected
|
||||
|
||||
def test_instance_modify(self):
|
||||
"""instance.modify() only affects that specific instance."""
|
||||
@ClassHook
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
obj1 = MyClass()
|
||||
obj2 = MyClass()
|
||||
|
||||
obj1.modify(lambda inst: setattr(inst, 'x', 999))
|
||||
assert obj1.x == 999
|
||||
assert obj2.x == 1
|
||||
|
||||
def test_multiple_deferred_hooks(self):
|
||||
"""Multiple modify() calls apply in order."""
|
||||
@ClassHook
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.log = []
|
||||
|
||||
MyClass.modify(lambda inst: inst.log.append("first"))
|
||||
MyClass.modify(lambda inst: inst.log.append("second"))
|
||||
|
||||
obj = MyClass()
|
||||
assert obj.log == ["first", "second"]
|
||||
|
||||
def test_getattr_delegation(self):
|
||||
"""ClassHook delegates attribute access to the wrapped class."""
|
||||
@ClassHook
|
||||
class MyClass:
|
||||
class_var = "hello"
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
assert MyClass.class_var == "hello"
|
||||
@@ -0,0 +1,327 @@
|
||||
"""Tests for connpy.plugins module."""
|
||||
import os
|
||||
import textwrap
|
||||
import pytest
|
||||
from connpy.plugins import Plugins
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: write a plugin script to a file
|
||||
# ---------------------------------------------------------------------------
|
||||
def _write_plugin(path, code):
|
||||
"""Write dedented code to a file."""
|
||||
with open(path, "w") as f:
|
||||
f.write(textwrap.dedent(code))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# verify_script tests
|
||||
# =========================================================================
|
||||
|
||||
class TestVerifyScript:
|
||||
def test_valid_parser_entrypoint(self, tmp_path):
|
||||
p = tmp_path / "good.py"
|
||||
_write_plugin(p, """\
|
||||
import argparse
|
||||
|
||||
class Parser:
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser()
|
||||
|
||||
class Entrypoint:
|
||||
def __init__(self, args, parser, connapp):
|
||||
pass
|
||||
""")
|
||||
plugins = Plugins()
|
||||
assert plugins.verify_script(str(p)) == False
|
||||
|
||||
def test_valid_preload_only(self, tmp_path):
|
||||
p = tmp_path / "preload.py"
|
||||
_write_plugin(p, """\
|
||||
class Preload:
|
||||
def __init__(self, connapp):
|
||||
pass
|
||||
""")
|
||||
plugins = Plugins()
|
||||
assert plugins.verify_script(str(p)) == False
|
||||
|
||||
def test_valid_all_three(self, tmp_path):
|
||||
p = tmp_path / "all.py"
|
||||
_write_plugin(p, """\
|
||||
import argparse
|
||||
|
||||
class Parser:
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser()
|
||||
|
||||
class Entrypoint:
|
||||
def __init__(self, args, parser, connapp):
|
||||
pass
|
||||
|
||||
class Preload:
|
||||
def __init__(self, connapp):
|
||||
pass
|
||||
""")
|
||||
plugins = Plugins()
|
||||
assert plugins.verify_script(str(p)) == False
|
||||
|
||||
def test_parser_without_entrypoint(self, tmp_path):
|
||||
p = tmp_path / "bad.py"
|
||||
_write_plugin(p, """\
|
||||
import argparse
|
||||
|
||||
class Parser:
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser()
|
||||
""")
|
||||
plugins = Plugins()
|
||||
result = plugins.verify_script(str(p))
|
||||
assert result # Should be a truthy error string
|
||||
assert "Entrypoint" in result
|
||||
|
||||
def test_entrypoint_without_parser(self, tmp_path):
|
||||
p = tmp_path / "bad.py"
|
||||
_write_plugin(p, """\
|
||||
class Entrypoint:
|
||||
def __init__(self, args, parser, connapp):
|
||||
pass
|
||||
""")
|
||||
plugins = Plugins()
|
||||
result = plugins.verify_script(str(p))
|
||||
assert result
|
||||
assert "Parser" in result
|
||||
|
||||
def test_no_valid_class(self, tmp_path):
|
||||
p = tmp_path / "empty.py"
|
||||
_write_plugin(p, """\
|
||||
def some_function():
|
||||
pass
|
||||
""")
|
||||
plugins = Plugins()
|
||||
result = plugins.verify_script(str(p))
|
||||
assert result
|
||||
assert "No valid class" in result
|
||||
|
||||
def test_parser_missing_self_parser(self, tmp_path):
|
||||
p = tmp_path / "bad.py"
|
||||
_write_plugin(p, """\
|
||||
class Parser:
|
||||
def __init__(self):
|
||||
self.something = "not parser"
|
||||
|
||||
class Entrypoint:
|
||||
def __init__(self, args, parser, connapp):
|
||||
pass
|
||||
""")
|
||||
plugins = Plugins()
|
||||
result = plugins.verify_script(str(p))
|
||||
assert result
|
||||
assert "self.parser" in result
|
||||
|
||||
def test_entrypoint_wrong_args(self, tmp_path):
|
||||
p = tmp_path / "bad.py"
|
||||
_write_plugin(p, """\
|
||||
import argparse
|
||||
|
||||
class Parser:
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser()
|
||||
|
||||
class Entrypoint:
|
||||
def __init__(self, args):
|
||||
pass
|
||||
""")
|
||||
plugins = Plugins()
|
||||
result = plugins.verify_script(str(p))
|
||||
assert result
|
||||
assert "Entrypoint" in result
|
||||
|
||||
def test_preload_wrong_args(self, tmp_path):
|
||||
p = tmp_path / "bad.py"
|
||||
_write_plugin(p, """\
|
||||
class Preload:
|
||||
def __init__(self, connapp, extra):
|
||||
pass
|
||||
""")
|
||||
plugins = Plugins()
|
||||
result = plugins.verify_script(str(p))
|
||||
assert result
|
||||
assert "Preload" in result
|
||||
|
||||
def test_disallowed_top_level(self, tmp_path):
|
||||
p = tmp_path / "bad.py"
|
||||
_write_plugin(p, """\
|
||||
MY_GLOBAL = "not allowed"
|
||||
|
||||
class Preload:
|
||||
def __init__(self, connapp):
|
||||
pass
|
||||
""")
|
||||
plugins = Plugins()
|
||||
result = plugins.verify_script(str(p))
|
||||
assert result
|
||||
assert "not allowed" in result.lower() or "Plugin can only have" in result
|
||||
|
||||
def test_syntax_error(self, tmp_path):
|
||||
p = tmp_path / "bad.py"
|
||||
_write_plugin(p, """\
|
||||
def broken(
|
||||
""")
|
||||
plugins = Plugins()
|
||||
result = plugins.verify_script(str(p))
|
||||
assert result
|
||||
assert "Syntax error" in result
|
||||
|
||||
def test_if_name_main_allowed(self, tmp_path):
|
||||
p = tmp_path / "good.py"
|
||||
_write_plugin(p, """\
|
||||
class Preload:
|
||||
def __init__(self, connapp):
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("standalone")
|
||||
""")
|
||||
plugins = Plugins()
|
||||
assert plugins.verify_script(str(p)) == False
|
||||
|
||||
def test_other_if_not_allowed(self, tmp_path):
|
||||
p = tmp_path / "bad.py"
|
||||
_write_plugin(p, """\
|
||||
import sys
|
||||
|
||||
if sys.platform == "linux":
|
||||
pass
|
||||
|
||||
class Preload:
|
||||
def __init__(self, connapp):
|
||||
pass
|
||||
""")
|
||||
plugins = Plugins()
|
||||
result = plugins.verify_script(str(p))
|
||||
assert result
|
||||
assert "__name__" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Import and loading tests
|
||||
# =========================================================================
|
||||
|
||||
class TestPluginLoading:
|
||||
def test_import_from_path(self, tmp_path):
|
||||
p = tmp_path / "mymod.py"
|
||||
_write_plugin(p, """\
|
||||
MY_VAR = 42
|
||||
""")
|
||||
plugins = Plugins()
|
||||
module = plugins._import_from_path(str(p))
|
||||
assert module.MY_VAR == 42
|
||||
|
||||
def test_import_plugins_to_argparse(self, tmp_path):
|
||||
"""Valid plugins get loaded into argparse."""
|
||||
import argparse
|
||||
|
||||
plugin_dir = tmp_path / "plugins"
|
||||
plugin_dir.mkdir()
|
||||
_write_plugin(plugin_dir / "myplugin.py", """\
|
||||
import argparse
|
||||
|
||||
class Parser:
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser(description="My plugin")
|
||||
|
||||
class Entrypoint:
|
||||
def __init__(self, args, parser, connapp):
|
||||
pass
|
||||
""")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
|
||||
plugins = Plugins()
|
||||
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
|
||||
|
||||
assert "myplugin" in plugins.plugins
|
||||
assert "myplugin" in plugins.plugin_parsers
|
||||
|
||||
def test_plugin_name_collision(self, tmp_path):
|
||||
"""Plugin with same name as existing subcommand is skipped."""
|
||||
import argparse
|
||||
|
||||
plugin_dir = tmp_path / "plugins"
|
||||
plugin_dir.mkdir()
|
||||
_write_plugin(plugin_dir / "existcmd.py", """\
|
||||
import argparse
|
||||
|
||||
class Parser:
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser()
|
||||
|
||||
class Entrypoint:
|
||||
def __init__(self, args, parser, connapp):
|
||||
pass
|
||||
""")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
subparsers.add_parser("existcmd") # Already taken
|
||||
|
||||
plugins = Plugins()
|
||||
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
|
||||
|
||||
assert "existcmd" not in plugins.plugins
|
||||
|
||||
def test_preload_registration(self, tmp_path):
|
||||
"""Preload class gets registered in preloads dict."""
|
||||
import argparse
|
||||
|
||||
plugin_dir = tmp_path / "plugins"
|
||||
plugin_dir.mkdir()
|
||||
_write_plugin(plugin_dir / "preloader.py", """\
|
||||
class Preload:
|
||||
def __init__(self, connapp):
|
||||
pass
|
||||
""")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
|
||||
plugins = Plugins()
|
||||
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
|
||||
|
||||
assert "preloader" in plugins.preloads
|
||||
|
||||
def test_invalid_plugin_skipped(self, tmp_path, capsys):
|
||||
"""Invalid plugin is skipped with error message."""
|
||||
import argparse
|
||||
|
||||
plugin_dir = tmp_path / "plugins"
|
||||
plugin_dir.mkdir()
|
||||
_write_plugin(plugin_dir / "badplugin.py", """\
|
||||
MY_GLOBAL = "bad"
|
||||
""")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
|
||||
plugins = Plugins()
|
||||
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
|
||||
|
||||
assert "badplugin" not in plugins.plugins
|
||||
captured = capsys.readouterr()
|
||||
assert "Failed to load plugin" in captured.err or "Failed to load plugin" in captured.out
|
||||
|
||||
def test_empty_directory(self, tmp_path):
|
||||
"""Empty directory doesn't cause errors."""
|
||||
import argparse
|
||||
|
||||
plugin_dir = tmp_path / "plugins"
|
||||
plugin_dir.mkdir()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
|
||||
plugins = Plugins()
|
||||
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
|
||||
|
||||
assert len(plugins.plugins) == 0
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Tests for connpy.printer module."""
|
||||
import sys
|
||||
from io import StringIO
|
||||
from connpy import printer
|
||||
|
||||
|
||||
class TestPrinter:
|
||||
def test_info_output(self, capsys):
|
||||
printer.info("hello world")
|
||||
captured = capsys.readouterr()
|
||||
assert "[i] hello world" in captured.out
|
||||
|
||||
def test_success_output(self, capsys):
|
||||
printer.success("done")
|
||||
captured = capsys.readouterr()
|
||||
assert "[✓] done" in captured.out
|
||||
|
||||
def test_warning_output(self, capsys):
|
||||
printer.warning("careful")
|
||||
captured = capsys.readouterr()
|
||||
assert "[!] careful" in captured.out
|
||||
|
||||
def test_error_output(self, capsys):
|
||||
printer.error("failed")
|
||||
captured = capsys.readouterr()
|
||||
assert "[✗] failed" in captured.err
|
||||
|
||||
def test_debug_output(self, capsys):
|
||||
printer.debug("debug info")
|
||||
captured = capsys.readouterr()
|
||||
assert "[d] debug info" in captured.out
|
||||
|
||||
def test_start_output(self, capsys):
|
||||
printer.start("starting")
|
||||
captured = capsys.readouterr()
|
||||
assert "[+] starting" in captured.out
|
||||
|
||||
def test_custom_output(self, capsys):
|
||||
printer.custom("TAG", "custom message")
|
||||
captured = capsys.readouterr()
|
||||
assert "[TAG] custom message" in captured.out
|
||||
|
||||
def test_multiline_indentation(self, capsys):
|
||||
printer.info("line1\nline2\nline3")
|
||||
captured = capsys.readouterr()
|
||||
lines = captured.out.strip().split("\n")
|
||||
assert lines[0] == "[i] line1"
|
||||
# Second line should be indented by len("[i] ") = 4 chars
|
||||
assert lines[1].startswith(" line2")
|
||||
assert lines[2].startswith(" line3")
|
||||
@@ -0,0 +1,5 @@
|
||||
[pytest]
|
||||
testpaths = connpy/tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
Reference in New Issue
Block a user