test: add comprehensive unit tests and fix bare excepts

- Added comprehensive unit test suite covering AI, core, plugins, completion, API, and hooks using pytest.
- Replaced bare 'except:' clauses across the codebase with specific exception handling (e.g., ValueError, KeyError, OSError) to prevent swallowing system exit calls.
- Fixed Python 3.8+ AST compatibility in plugins.py (support for ast.Constant).
- Removed deprecated pkg_resources import from __init__.py.
- Fixed missing 'printer' import in configfile.py that caused NameErrors during save failures.
This commit is contained in:
2026-04-03 17:11:45 -03:00
parent 7de6003435
commit cf95befb43
21 changed files with 2490 additions and 56 deletions
-1
View File
@@ -542,7 +542,6 @@ from .api import *
from .ai import ai from .ai import ai
from .plugins import Plugins from .plugins import Plugins
from ._version import __version__ from ._version import __version__
from pkg_resources import get_distribution
from . import printer from . import printer
__all__ = ["node", "nodes", "configfile", "connapp", "ai", "Plugins", "printer"] __all__ = ["node", "nodes", "configfile", "connapp", "ai", "Plugins", "printer"]
+7 -6
View File
@@ -396,7 +396,7 @@ class ai:
if isinstance(commands, str): if isinstance(commands, str):
try: try:
commands = json.loads(commands) commands = json.loads(commands)
except: except ValueError:
commands = [c.strip() for c in commands.split('\n') if c.strip()] commands = [c.strip() for c in commands.split('\n') if c.strip()]
# Expand multi-line commands within a list (in case the AI packs them) # Expand multi-line commands within a list (in case the AI packs them)
@@ -795,9 +795,10 @@ class ai:
response = completion(model=model, messages=safe_messages, tools=[], api_key=key) response = completion(model=model, messages=safe_messages, tools=[], api_key=key)
resp_msg = response.choices[0].message resp_msg = response.choices[0].message
messages.append(resp_msg.model_dump(exclude_none=True)) messages.append(resp_msg.model_dump(exclude_none=True))
except: except Exception as e:
pass if status:
status.update(f"[bold red]Error fetching summary: {e}[/bold red]")
printer.warning(f"Failed to fetch final summary from LLM: {e}")
except KeyboardInterrupt: except KeyboardInterrupt:
if status: status.update("[bold red]Interrupted! Closing pending tasks...") if status: status.update("[bold red]Interrupted! Closing pending tasks...")
last_msg = messages[-1] last_msg = messages[-1]
@@ -810,7 +811,7 @@ class ai:
response = completion(model=model, messages=safe_messages, tools=tools, api_key=key) response = completion(model=model, messages=safe_messages, tools=tools, api_key=key)
resp_msg = response.choices[0].message resp_msg = response.choices[0].message
messages.append(resp_msg.model_dump(exclude_none=True)) messages.append(resp_msg.model_dump(exclude_none=True))
except: pass except Exception: pass
finally: finally:
try: try:
log_dir = self.config.defaultdir log_dir = self.config.defaultdir
@@ -820,7 +821,7 @@ class ai:
if os.path.exists(log_path): if os.path.exists(log_path):
try: try:
with open(log_path, "r") as f: hist = json.load(f) with open(log_path, "r") as f: hist = json.load(f)
except: hist = [] except (IOError, json.JSONDecodeError): hist = []
hist.append({"timestamp": datetime.datetime.now().isoformat(), "roles": {"strategic_engine": self.architect_model, "execution_engine": self.engineer_model}, "session": messages}) hist.append({"timestamp": datetime.datetime.now().isoformat(), "roles": {"strategic_engine": self.architect_model, "execution_engine": self.engineer_model}, "session": messages})
with open(log_path, "w") as f: json.dump(hist[-10:], f, indent=4) with open(log_path, "w") as f: json.dump(hist[-10:], f, indent=4)
except Exception as e: except Exception as e:
+10 -10
View File
@@ -35,7 +35,7 @@ def list_nodes():
else: else:
filter = filter.lower() filter = filter.lower()
output = conf._getallnodes(filter) output = conf._getallnodes(filter)
except: except Exception:
output = conf._getallnodes() output = conf._getallnodes()
return jsonify(output) return jsonify(output)
@@ -52,7 +52,7 @@ def get_nodes():
else: else:
filter = filter.lower() filter = filter.lower()
output = conf._getallnodesfull(filter) output = conf._getallnodesfull(filter)
except: except Exception:
output = conf._getallnodesfull() output = conf._getallnodesfull()
return jsonify(output) return jsonify(output)
@@ -109,13 +109,13 @@ def run_commands():
mynodes = nodes(mynodes, config=conf) mynodes = nodes(mynodes, config=conf)
try: try:
args["vars"] = data["vars"] args["vars"] = data["vars"]
except: except Exception:
pass pass
try: try:
options = data["options"] options = data["options"]
thisoptions = {k: v for k, v in options.items() if k in ["prompt", "parallel", "timeout"]} thisoptions = {k: v for k, v in options.items() if k in ["prompt", "parallel", "timeout"]}
args.update(thisoptions) args.update(thisoptions)
except: except Exception:
options = None options = None
if action == "run": if action == "run":
output = mynodes.run(**args) output = mynodes.run(**args)
@@ -136,20 +136,20 @@ def stop_api():
pid = int(f.readline().strip()) pid = int(f.readline().strip())
port = int(f.readline().strip()) port = int(f.readline().strip())
PID_FILE=PID_FILE1 PID_FILE=PID_FILE1
except: except (FileNotFoundError, ValueError, OSError):
try: try:
with open(PID_FILE2, "r") as f: with open(PID_FILE2, "r") as f:
pid = int(f.readline().strip()) pid = int(f.readline().strip())
port = int(f.readline().strip()) port = int(f.readline().strip())
PID_FILE=PID_FILE2 PID_FILE=PID_FILE2
except: except (FileNotFoundError, ValueError, OSError):
printer.warning("Connpy API server is not running.") printer.warning("Connpy API server is not running.")
return return
# Send a SIGTERM signal to the process # Send a SIGTERM signal to the process
try: try:
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
except: except OSError as e:
pass printer.warning(f"Process kill failed (maybe already dead): {e}")
# Delete the PID file # Delete the PID file
os.remove(PID_FILE) os.remove(PID_FILE)
printer.info(f"Server with process ID {pid} stopped.") printer.info(f"Server with process ID {pid} stopped.")
@@ -177,11 +177,11 @@ def start_api(port=8048):
try: try:
with open(PID_FILE1, "w") as f: with open(PID_FILE1, "w") as f:
f.write(str(pid) + "\n" + str(port)) f.write(str(pid) + "\n" + str(port))
except: except OSError:
try: try:
with open(PID_FILE2, "w") as f: with open(PID_FILE2, "w") as f:
f.write(str(pid) + "\n" + str(port)) f.write(str(pid) + "\n" + str(port))
except: except OSError:
printer.error("Couldn't create PID file.") printer.error("Couldn't create PID file.")
exit(1) exit(1)
printer.start(f"Server is running with process ID {pid} on port {port}") printer.start(f"Server is running with process ID {pid} on port {port}")
+2 -2
View File
@@ -95,7 +95,7 @@ def main():
try: try:
with open(pathfile, "r") as f: with open(pathfile, "r") as f:
configdir = f.read().strip() configdir = f.read().strip()
except: except (FileNotFoundError, IOError):
configdir = defaultdir configdir = defaultdir
defaultfile = configdir + '/config.json' defaultfile = configdir + '/config.json'
jsonconf = open(defaultfile) jsonconf = open(defaultfile)
@@ -131,7 +131,7 @@ def main():
spec.loader.exec_module(module) spec.loader.exec_module(module)
plugin_completion = getattr(module, "_connpy_completion") plugin_completion = getattr(module, "_connpy_completion")
strings = plugin_completion(wordsnumber, words, info) strings = plugin_completion(wordsnumber, words, info)
except: except Exception:
exit() exit()
elif wordsnumber >= 3 and words[0] == "ai": elif wordsnumber >= 3 and words[0] == "ai":
if wordsnumber == 3: if wordsnumber == 3:
+10 -8
View File
@@ -8,6 +8,7 @@ from Crypto.Cipher import PKCS1_OAEP
from pathlib import Path from pathlib import Path
from copy import deepcopy from copy import deepcopy
from .hooks import MethodHook, ClassHook from .hooks import MethodHook, ClassHook
from . import printer
@@ -60,7 +61,7 @@ class configfile:
try: try:
with open(pathfile, "r") as f: with open(pathfile, "r") as f:
configdir = f.read().strip() configdir = f.read().strip()
except: except (FileNotFoundError, IOError):
with open(pathfile, "w") as f: with open(pathfile, "w") as f:
f.write(str(defaultdir)) f.write(str(defaultdir))
configdir = defaultdir configdir = defaultdir
@@ -120,7 +121,8 @@ class configfile:
with open(conf, "w") as f: with open(conf, "w") as f:
json.dump(newconfig, f, indent = 4) json.dump(newconfig, f, indent = 4)
f.close() f.close()
except: except (IOError, OSError) as e:
printer.error(f"Failed to save config: {e}")
return 1 return 1
return 0 return 0
@@ -205,12 +207,12 @@ class configfile:
if profile: if profile:
try: try:
newfolder[node_name][key] = self.profiles[profile.group(1)][key] newfolder[node_name][key] = self.profiles[profile.group(1)][key]
except: except KeyError:
newfolder[node_name][key] = "" newfolder[node_name][key] = ""
elif value == '' and key == "protocol": elif value == '' and key == "protocol":
try: try:
newfolder[node_name][key] = self.profiles["default"][key] newfolder[node_name][key] = self.profiles["default"][key]
except: except KeyError:
newfolder[node_name][key] = "ssh" newfolder[node_name][key] = "ssh"
newfolder = {"{}{}".format(k,unique):v for k,v in newfolder.items()} newfolder = {"{}{}".format(k,unique):v for k,v in newfolder.items()}
@@ -231,12 +233,12 @@ class configfile:
if profile: if profile:
try: try:
newnode[key] = self.profiles[profile.group(1)][key] newnode[key] = self.profiles[profile.group(1)][key]
except: except KeyError:
newnode[key] = "" newnode[key] = ""
elif value == '' and key == "protocol": elif value == '' and key == "protocol":
try: try:
newnode[key] = self.profiles["default"][key] newnode[key] = self.profiles["default"][key]
except: except KeyError:
newnode[key] = "ssh" newnode[key] = "ssh"
return newnode return newnode
@@ -391,12 +393,12 @@ class configfile:
if profile: if profile:
try: try:
nodes[node][key] = self.profiles[profile.group(1)][key] nodes[node][key] = self.profiles[profile.group(1)][key]
except: except KeyError:
nodes[node][key] = "" nodes[node][key] = ""
elif value == '' and key == "protocol": elif value == '' and key == "protocol":
try: try:
nodes[node][key] = self.profiles["default"][key] nodes[node][key] = self.profiles["default"][key]
except: except KeyError:
nodes[node][key] = "ssh" nodes[node][key] = "ssh"
return nodes return nodes
+16 -17
View File
@@ -17,7 +17,6 @@ import shutil
class NoAliasDumper(yaml.SafeDumper): class NoAliasDumper(yaml.SafeDumper):
def ignore_aliases(self, data): def ignore_aliases(self, data):
return True return True
import ast
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.console import Console, Group from rich.console import Console, Group
from rich.panel import Panel from rich.panel import Panel
@@ -29,7 +28,7 @@ mdprint = Console().print
console = Console() console = Console()
try: try:
from pyfzf.pyfzf import FzfPrompt from pyfzf.pyfzf import FzfPrompt
except: except ImportError:
FzfPrompt = None FzfPrompt = None
@@ -64,7 +63,7 @@ class connapp:
self.case = self.config.config["case"] self.case = self.config.config["case"]
try: try:
self.fzf = self.config.config["fzf"] self.fzf = self.config.config["fzf"]
except: except KeyError:
self.fzf = False self.fzf = False
@@ -178,13 +177,13 @@ class connapp:
try: try:
core_path = os.path.dirname(os.path.realpath(__file__)) + "/core_plugins" core_path = os.path.dirname(os.path.realpath(__file__)) + "/core_plugins"
self.plugins._import_plugins_to_argparse(core_path, subparsers) self.plugins._import_plugins_to_argparse(core_path, subparsers)
except: except Exception as e:
pass printer.warning(e)
try: try:
file_path = self.config.defaultdir + "/plugins" file_path = self.config.defaultdir + "/plugins"
self.plugins._import_plugins_to_argparse(file_path, subparsers) self.plugins._import_plugins_to_argparse(file_path, subparsers)
except: except Exception as e:
pass printer.warning(e)
for preload in self.plugins.preloads.values(): for preload in self.plugins.preloads.values():
preload.Preload(self) preload.Preload(self)
#Generate helps #Generate helps
@@ -826,7 +825,7 @@ class connapp:
try: try:
with open(args.data[0]) as file: with open(args.data[0]) as file:
imported = yaml.load(file, Loader=yaml.FullLoader) imported = yaml.load(file, Loader=yaml.FullLoader)
except: except Exception:
printer.error("failed reading file {}".format(args.data[0])) printer.error("failed reading file {}".format(args.data[0]))
exit(10) exit(10)
for k,v in imported.items(): for k,v in imported.items():
@@ -1013,7 +1012,7 @@ class connapp:
try: try:
with open(args.data[0]) as file: with open(args.data[0]) as file:
scripts = yaml.load(file, Loader=yaml.FullLoader) scripts = yaml.load(file, Loader=yaml.FullLoader)
except: except Exception:
printer.error("failed reading file {}".format(args.data[0])) printer.error("failed reading file {}".format(args.data[0]))
exit(10) exit(10)
for script in scripts["tasks"]: for script in scripts["tasks"]:
@@ -1053,13 +1052,13 @@ class connapp:
options = script["options"] options = script["options"]
thisoptions = {k: v for k, v in options.items() if k in ["prompt", "parallel", "timeout"]} thisoptions = {k: v for k, v in options.items() if k in ["prompt", "parallel", "timeout"]}
args.update(thisoptions) args.update(thisoptions)
except: except KeyError:
options = None options = None
try: try:
size = str(os.get_terminal_size()) size = str(os.get_terminal_size())
p = re.search(r'.*columns=([0-9]+)', size) p = re.search(r'.*columns=([0-9]+)', size)
columns = int(p.group(1)) columns = int(p.group(1))
except: except (ValueError, OSError):
columns = 80 columns = 80
PANEL_WIDTH = columns PANEL_WIDTH = columns
@@ -1182,7 +1181,7 @@ class connapp:
raise inquirer.errors.ValidationError("", reason="Pick a port between 1-65535, @profile o leave empty") raise inquirer.errors.ValidationError("", reason="Pick a port between 1-65535, @profile o leave empty")
try: try:
port = int(current) port = int(current)
except: except ValueError:
port = 0 port = 0
if current != "" and not 1 <= int(port) <= 65535: if current != "" and not 1 <= int(port) <= 65535:
raise inquirer.errors.ValidationError("", reason="Pick a port between 1-65535 or leave empty") raise inquirer.errors.ValidationError("", reason="Pick a port between 1-65535 or leave empty")
@@ -1194,7 +1193,7 @@ class connapp:
raise inquirer.errors.ValidationError("", reason="Pick a port between 1-6553/app5, @profile or leave empty") raise inquirer.errors.ValidationError("", reason="Pick a port between 1-6553/app5, @profile or leave empty")
try: try:
port = int(current) port = int(current)
except: except ValueError:
port = 0 port = 0
if current.startswith("@"): if current.startswith("@"):
if current[1:] not in self.profiles: if current[1:] not in self.profiles:
@@ -1220,7 +1219,7 @@ class connapp:
isdict = False isdict = False
try: try:
isdict = ast.literal_eval(current) isdict = ast.literal_eval(current)
except: except Exception:
pass pass
if not isinstance (isdict, dict): if not isinstance (isdict, dict):
raise inquirer.errors.ValidationError("", reason="Tags should be a python dictionary.".format(current)) raise inquirer.errors.ValidationError("", reason="Tags should be a python dictionary.".format(current))
@@ -1232,7 +1231,7 @@ class connapp:
isdict = False isdict = False
try: try:
isdict = ast.literal_eval(current) isdict = ast.literal_eval(current)
except: except Exception:
pass pass
if not isinstance (isdict, dict): if not isinstance (isdict, dict):
raise inquirer.errors.ValidationError("", reason="Tags should be a python dictionary.".format(current)) raise inquirer.errors.ValidationError("", reason="Tags should be a python dictionary.".format(current))
@@ -1316,7 +1315,7 @@ class connapp:
defaults["tags"] = "" defaults["tags"] = ""
if "jumphost" not in defaults: if "jumphost" not in defaults:
defaults["jumphost"] = "" defaults["jumphost"] = ""
except: except KeyError:
defaults = { "host":"", "protocol":"", "port":"", "user":"", "options":"", "logs":"" , "tags":"", "password":"", "jumphost":""} defaults = { "host":"", "protocol":"", "port":"", "user":"", "options":"", "logs":"" , "tags":"", "password":"", "jumphost":""}
node = {} node = {}
if edit == None: if edit == None:
@@ -1390,7 +1389,7 @@ class connapp:
defaults["tags"] = "" defaults["tags"] = ""
if "jumphost" not in defaults: if "jumphost" not in defaults:
defaults["jumphost"] = "" defaults["jumphost"] = ""
except: except KeyError:
defaults = { "host":"", "protocol":"", "port":"", "user":"", "options":"", "logs":"", "tags": "", "jumphost": ""} defaults = { "host":"", "protocol":"", "port":"", "user":"", "options":"", "logs":"", "tags": "", "jumphost": ""}
profile = {} profile = {}
if edit == None: if edit == None:
+5 -5
View File
@@ -84,12 +84,12 @@ class node:
if profile and config != '': if profile and config != '':
try: try:
setattr(self,key,config.profiles[profile.group(1)][key]) setattr(self,key,config.profiles[profile.group(1)][key])
except: except KeyError:
setattr(self,key,"") setattr(self,key,"")
elif attr[key] == '' and key == "protocol": elif attr[key] == '' and key == "protocol":
try: try:
setattr(self,key,config.profiles["default"][key]) setattr(self,key,config.profiles["default"][key])
except: except (KeyError, AttributeError):
setattr(self,key,"ssh") setattr(self,key,"ssh")
else: else:
setattr(self,key,attr[key]) setattr(self,key,attr[key])
@@ -108,12 +108,12 @@ class node:
if profile: if profile:
try: try:
self.jumphost[key] = config.profiles[profile.group(1)][key] self.jumphost[key] = config.profiles[profile.group(1)][key]
except: except KeyError:
self.jumphost[key] = "" self.jumphost[key] = ""
elif self.jumphost[key] == '' and key == "protocol": elif self.jumphost[key] == '' and key == "protocol":
try: try:
self.jumphost[key] = config.profiles["default"][key] self.jumphost[key] = config.profiles["default"][key]
except: except KeyError:
self.jumphost[key] = "ssh" self.jumphost[key] = "ssh"
if isinstance(self.jumphost["password"],list): if isinstance(self.jumphost["password"],list):
jumphost_password = [] jumphost_password = []
@@ -158,7 +158,7 @@ class node:
try: try:
decrypted = decryptor.decrypt(ast.literal_eval(passwd)).decode("utf-8") decrypted = decryptor.decrypt(ast.literal_eval(passwd)).decode("utf-8")
dpass.append(decrypted) dpass.append(decrypted)
except: except Exception:
raise ValueError("Missing or corrupted key") raise ValueError("Missing or corrupted key")
return dpass return dpass
+3 -3
View File
@@ -176,7 +176,7 @@ class RemoteCapture:
printer.success("Tcpdump finished capturing packets.") printer.success("Tcpdump finished capturing packets.")
self.listener_active = False self.listener_active = False
except: except Exception:
pass pass
def _sendline_until_connected(self, cmd, retries=5, interval=2): def _sendline_until_connected(self, cmd, retries=5, interval=2):
@@ -307,7 +307,7 @@ class RemoteCapture:
try: try:
self.fake_connection = True self.fake_connection = True
socket.create_connection(("localhost", self.local_port), timeout=1).close() socket.create_connection(("localhost", self.local_port), timeout=1).close()
except: except OSError:
pass pass
self.listener_active = False self.listener_active = False
return return
@@ -324,7 +324,7 @@ class RemoteCapture:
try: try:
self.listener_conn.shutdown(socket.SHUT_RDWR) self.listener_conn.shutdown(socket.SHUT_RDWR)
self.listener_conn.close() self.listener_conn.close()
except: except OSError:
pass pass
if hasattr(self.node, "child"): if hasattr(self.node, "child"):
self.node.child.close(force=True) self.node.child.close(force=True)
+2 -2
View File
@@ -28,7 +28,7 @@ class sync:
self.connapp = connapp self.connapp = connapp
try: try:
self.sync = self.connapp.config.config["sync"] self.sync = self.connapp.config.config["sync"]
except: except KeyError:
self.sync = False self.sync = False
def login(self): def login(self):
@@ -322,7 +322,7 @@ class sync:
def config_listener_pre(self, *args, **kwargs): def config_listener_pre(self, *args, **kwargs):
try: try:
self.sync = self.connapp.config.config["sync"] self.sync = self.connapp.config.config["sync"]
except: except KeyError:
self.sync = False self.sync = False
return args, kwargs return args, kwargs
+2 -2
View File
@@ -62,8 +62,8 @@ class Plugins:
if not (isinstance(node.test, ast.Compare) and if not (isinstance(node.test, ast.Compare) and
isinstance(node.test.left, ast.Name) and isinstance(node.test.left, ast.Name) and
node.test.left.id == '__name__' and node.test.left.id == '__name__' and
isinstance(node.test.comparators[0], ast.Str) and ((hasattr(ast, 'Str') and isinstance(node.test.comparators[0], getattr(ast, 'Str')) and node.test.comparators[0].s == '__main__') or
node.test.comparators[0].s == '__main__'): (hasattr(ast, 'Constant') and isinstance(node.test.comparators[0], getattr(ast, 'Constant')) and node.test.comparators[0].value == '__main__'))):
return "Only __name__ == __main__ If is allowed" return "Only __name__ == __main__ If is allowed"
elif not isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Import, ast.ImportFrom, ast.Pass)): elif not isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Import, ast.ImportFrom, ast.Pass)):
+1
View File
@@ -0,0 +1 @@
# Tests package
+192
View File
@@ -0,0 +1,192 @@
"""Shared fixtures for connpy tests.
All tests use tmp_path to create isolated config/keys.
No test touches ~/.config/conn/
"""
import pytest
import json
import os
from unittest.mock import patch, MagicMock
from Crypto.PublicKey import RSA
# ---------------------------------------------------------------------------
# Minimal config data
# ---------------------------------------------------------------------------
DEFAULT_CONFIG = {
"config": {"case": False, "idletime": 30, "fzf": False},
"connections": {},
"profiles": {
"default": {
"host": "", "protocol": "ssh", "port": "", "user": "",
"password": "", "options": "", "logs": "", "tags": "", "jumphost": ""
}
}
}
SAMPLE_CONNECTIONS = {
"router1": {
"host": "10.0.0.1", "protocol": "ssh", "port": "22",
"user": "admin", "password": "pass1", "options": "",
"logs": "", "tags": "", "jumphost": "", "type": "connection"
},
"office": {
"type": "folder",
"server1": {
"host": "10.0.1.1", "protocol": "ssh", "port": "",
"user": "root", "password": "pass2", "options": "",
"logs": "", "tags": "", "jumphost": "", "type": "connection"
},
"datacenter": {
"type": "subfolder",
"db1": {
"host": "10.0.2.1", "protocol": "ssh", "port": "",
"user": "dbadmin", "password": "pass3", "options": "",
"logs": "", "tags": "", "jumphost": "", "type": "connection"
}
}
}
}
SAMPLE_PROFILES = {
"default": {
"host": "", "protocol": "ssh", "port": "", "user": "",
"password": "", "options": "", "logs": "", "tags": "", "jumphost": ""
},
"office-user": {
"host": "", "protocol": "ssh", "port": "", "user": "officeadmin",
"password": "officepass", "options": "", "logs": "", "tags": "", "jumphost": ""
}
}
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def tmp_config_dir(tmp_path):
"""Create an isolated config directory with config.json and RSA key."""
config_dir = tmp_path / ".config" / "conn"
config_dir.mkdir(parents=True)
plugins_dir = config_dir / "plugins"
plugins_dir.mkdir()
# Write config.json
config_file = config_dir / "config.json"
config_file.write_text(json.dumps(DEFAULT_CONFIG, indent=4))
os.chmod(str(config_file), 0o600)
# Write .folder (points to itself)
folder_file = config_dir / ".folder"
folder_file.write_text(str(config_dir))
# Generate RSA key
key = RSA.generate(2048)
key_file = config_dir / ".osk"
key_file.write_bytes(key.export_key("PEM"))
os.chmod(str(key_file), 0o600)
return config_dir
@pytest.fixture
def config(tmp_config_dir):
"""Create a configfile instance pointing to tmp directory."""
from connpy.configfile import configfile
conf_path = str(tmp_config_dir / "config.json")
key_path = str(tmp_config_dir / ".osk")
return configfile(conf=conf_path, key=key_path)
@pytest.fixture
def populated_config(tmp_config_dir):
"""Create a configfile with sample nodes/profiles pre-loaded."""
config_file = tmp_config_dir / "config.json"
data = {
"config": {"case": False, "idletime": 30, "fzf": False},
"connections": SAMPLE_CONNECTIONS,
"profiles": SAMPLE_PROFILES
}
config_file.write_text(json.dumps(data, indent=4))
from connpy.configfile import configfile
return configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
@pytest.fixture
def mock_pexpect():
"""Mock pexpect.spawn for connection tests."""
with patch("connpy.core.pexpect") as mock_pexp:
child = MagicMock()
child.before = b""
child.after = b"router#"
child.readline.return_value = b""
child.child_fd = 3
mock_pexp.spawn.return_value = child
mock_pexp.EOF = object()
mock_pexp.TIMEOUT = object()
# Also mock fdpexpect
with patch("connpy.core.fdpexpect", create=True) as mock_fd:
mock_fd.fdspawn.return_value = MagicMock()
yield {
"pexpect": mock_pexp,
"child": child,
"fdpexpect": mock_fd
}
@pytest.fixture
def mock_litellm():
"""Mock litellm.completion for AI tests."""
with patch("connpy.ai.completion") as mock_comp:
# Create a default response
msg = MagicMock()
msg.content = "Test response from AI"
msg.tool_calls = None
msg.role = "assistant"
msg.model_dump.return_value = {
"role": "assistant",
"content": "Test response from AI"
}
choice = MagicMock()
choice.message = msg
response = MagicMock()
response.choices = [choice]
response.usage = MagicMock()
response.usage.prompt_tokens = 100
response.usage.completion_tokens = 50
response.usage.total_tokens = 150
mock_comp.return_value = response
yield {
"completion": mock_comp,
"response": response,
"message": msg,
"choice": choice
}
@pytest.fixture
def ai_config(tmp_config_dir):
"""Create a configfile with AI keys configured for AI tests."""
config_file = tmp_config_dir / "config.json"
data = {
"config": {
"case": False, "idletime": 30, "fzf": False,
"ai": {
"engineer_model": "test/test-model",
"engineer_api_key": "test-engineer-key",
"architect_model": "test/test-architect",
"architect_api_key": "test-architect-key"
}
},
"connections": SAMPLE_CONNECTIONS,
"profiles": SAMPLE_PROFILES
}
config_file.write_text(json.dumps(data, indent=4))
from connpy.configfile import configfile
return configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
+397
View File
@@ -0,0 +1,397 @@
"""Tests for connpy.ai module."""
import json
import os
import pytest
from unittest.mock import patch, MagicMock
# =========================================================================
# AI Init tests
# =========================================================================
class TestAIInit:
def test_init_with_keys(self, ai_config, mock_litellm):
"""Initializes correctly when keys are configured."""
from connpy.ai import ai
myai = ai(ai_config)
assert myai.engineer_model == "test/test-model"
assert myai.architect_model == "test/test-architect"
def test_init_missing_engineer_key(self, config):
"""Raises ValueError if engineer key is missing."""
from connpy.ai import ai
with pytest.raises(ValueError, match="Engineer API key"):
ai(config)
def test_init_missing_architect_key_warns(self, ai_config, capsys, mock_litellm):
"""Warns if architect key is missing but doesn't crash."""
# Remove architect key
ai_config.config["ai"]["architect_api_key"] = None
from connpy.ai import ai
# Should not raise
myai = ai(ai_config)
assert myai.architect_key is None
def test_default_models(self, config):
"""Default models are set correctly when not configured."""
config.config["ai"] = {"engineer_api_key": "test-key", "architect_api_key": "test-key"}
from connpy.ai import ai
myai = ai(config)
assert "gemini" in myai.engineer_model.lower()
assert "claude" in myai.architect_model.lower() or "anthropic" in myai.architect_model.lower()
def test_init_loads_memory(self, ai_config, tmp_path, mock_litellm):
"""Loads long-term memory from file if it exists."""
memory_path = os.path.expanduser("~/.config/conn/ai_memory.md")
from connpy.ai import ai
with patch("os.path.exists", side_effect=lambda p: True if p == memory_path else os.path.exists(p)):
with patch("builtins.open", side_effect=lambda f, *a, **kw: (
__import__("io").StringIO("## Memory\nRouter1 is border router")
if f == memory_path else open(f, *a, **kw)
)):
try:
myai = ai(ai_config)
except Exception:
pass # May fail on other file opens, that's ok
# =========================================================================
# register_ai_tool tests
# =========================================================================
class TestRegisterAITool:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def _make_tool_def(self, name="my_tool"):
return {
"type": "function",
"function": {
"name": name,
"description": "Test tool",
"parameters": {"type": "object", "properties": {}}
}
}
def test_register_tool_engineer(self, myai):
tool_def = self._make_tool_def()
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="engineer")
assert len(myai.external_engineer_tools) == 1
assert len(myai.external_architect_tools) == 0
def test_register_tool_architect(self, myai):
tool_def = self._make_tool_def()
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="architect")
assert len(myai.external_architect_tools) == 1
assert len(myai.external_engineer_tools) == 0
def test_register_tool_both(self, myai):
tool_def = self._make_tool_def()
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="both")
assert len(myai.external_engineer_tools) == 1
assert len(myai.external_architect_tools) == 1
def test_register_tool_handler(self, myai):
tool_def = self._make_tool_def("custom_tool")
handler = lambda self, **kw: "result"
myai.register_ai_tool(tool_def, handler)
assert "custom_tool" in myai.external_tool_handlers
assert myai.external_tool_handlers["custom_tool"] is handler
def test_register_tool_prompt_extension(self, myai):
tool_def = self._make_tool_def()
myai.register_ai_tool(
tool_def, lambda self, **kw: "ok",
engineer_prompt="- Custom capability",
architect_prompt=" * Custom tool"
)
assert any("Custom capability" in ext for ext in myai.engineer_prompt_extensions)
assert any("Custom tool" in ext for ext in myai.architect_prompt_extensions)
def test_register_tool_status_formatter(self, myai):
tool_def = self._make_tool_def("status_tool")
formatter = lambda args: f"[STATUS] {args}"
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", status_formatter=formatter)
assert "status_tool" in myai.tool_status_formatters
# =========================================================================
# Dynamic prompts tests
# =========================================================================
class TestDynamicPrompts:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_engineer_prompt_without_extensions(self, myai):
prompt = myai.engineer_system_prompt
assert "Plugin Capabilities" not in prompt
assert "TECHNICAL EXECUTION ENGINE" in prompt
def test_engineer_prompt_with_extensions(self, myai):
myai.engineer_prompt_extensions.append("- AWS Cloud Auditing")
prompt = myai.engineer_system_prompt
assert "Plugin Capabilities" in prompt
assert "AWS Cloud Auditing" in prompt
def test_architect_prompt_without_extensions(self, myai):
prompt = myai.architect_system_prompt
assert "Plugin Capabilities" not in prompt
assert "STRATEGIC REASONING ENGINE" in prompt
def test_architect_prompt_with_extensions(self, myai):
myai.architect_prompt_extensions.append(" * Custom tool available")
prompt = myai.architect_system_prompt
assert "Plugin Capabilities" in prompt
assert "Custom tool available" in prompt
# =========================================================================
# _sanitize_messages tests
# =========================================================================
class TestSanitizeMessages:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_sanitize_empty(self, myai):
assert myai._sanitize_messages([]) == []
def test_sanitize_normal_messages(self, myai):
messages = [
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"}
]
result = myai._sanitize_messages(messages)
assert len(result) == 3
def test_sanitize_removes_orphan_tool_calls(self, myai):
"""Tool calls at the end without responses are removed."""
messages = [
{"role": "user", "content": "do something"},
{"role": "assistant", "content": None, "tool_calls": [
{"id": "tc1", "function": {"name": "list_nodes", "arguments": "{}"}}
]}
# No tool response follows!
]
result = myai._sanitize_messages(messages)
assert len(result) == 1 # Only user message
assert result[0]["role"] == "user"
def test_sanitize_removes_orphan_tool_responses(self, myai):
"""Tool responses without preceding tool_calls are removed."""
messages = [
{"role": "user", "content": "hello"},
{"role": "tool", "tool_call_id": "tc1", "name": "list_nodes", "content": "[]"}
]
result = myai._sanitize_messages(messages)
assert len(result) == 1
assert result[0]["role"] == "user"
def test_sanitize_preserves_valid_tool_pairs(self, myai):
"""Valid assistant+tool_calls followed by tool responses are preserved."""
messages = [
{"role": "user", "content": "list nodes"},
{"role": "assistant", "content": None, "tool_calls": [
{"id": "tc1", "function": {"name": "list_nodes", "arguments": "{}"}}
]},
{"role": "tool", "tool_call_id": "tc1", "name": "list_nodes", "content": "[\"r1\"]"},
{"role": "assistant", "content": "Found r1"}
]
result = myai._sanitize_messages(messages)
assert len(result) == 4
# =========================================================================
# _truncate tests
# =========================================================================
class TestTruncate:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_truncate_short_text(self, myai):
text = "short text"
assert myai._truncate(text) == text
def test_truncate_long_text(self, myai):
text = "x" * 100000
result = myai._truncate(text)
assert len(result) < 100000
assert "[... OUTPUT TRUNCATED ...]" in result
def test_truncate_custom_limit(self, myai):
text = "x" * 1000
result = myai._truncate(text, limit=500)
assert len(result) < 1000
assert "[... OUTPUT TRUNCATED ...]" in result
def test_truncate_preserves_head_and_tail(self, myai):
text = "HEAD" + "x" * 100000 + "TAIL"
result = myai._truncate(text)
assert result.startswith("HEAD")
assert result.endswith("TAIL")
# =========================================================================
# Tool methods tests
# =========================================================================
class TestToolMethods:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_list_nodes_tool_found(self, myai):
result = myai.list_nodes_tool("router.*")
parsed = json.loads(result)
assert "router1" in str(parsed)
def test_list_nodes_tool_not_found(self, myai):
result = myai.list_nodes_tool("nonexistent_pattern_xyz")
assert "No nodes found" in result
def test_get_node_info_masks_password(self, myai):
result = myai.get_node_info_tool("router1")
parsed = json.loads(result)
assert parsed["password"] == "***"
def test_is_safe_command_show(self, myai):
assert myai._is_safe_command("show running-config") == True
assert myai._is_safe_command("show ip int brief") == True
def test_is_safe_command_config(self, myai):
assert myai._is_safe_command("config t") == False
assert myai._is_safe_command("write memory") == False
def test_is_safe_command_ls(self, myai):
assert myai._is_safe_command("ls -la") == True
def test_is_safe_command_ping(self, myai):
assert myai._is_safe_command("ping 10.0.0.1") == True
# =========================================================================
# manage_memory_tool tests
# =========================================================================
class TestManageMemory:
@pytest.fixture
def myai(self, ai_config, mock_litellm, tmp_path):
from connpy.ai import ai
myai = ai(ai_config)
myai.memory_path = str(tmp_path / "ai_memory.md")
return myai
def test_manage_memory_append(self, myai):
result = myai.manage_memory_tool("Router1 is border router", action="append")
assert "successfully" in result.lower()
assert os.path.exists(myai.memory_path)
content = open(myai.memory_path).read()
assert "Router1 is border router" in content
def test_manage_memory_replace(self, myai):
myai.manage_memory_tool("old content", action="append")
myai.manage_memory_tool("new content only", action="replace")
content = open(myai.memory_path).read()
assert "new content only" in content
assert "old content" not in content
def test_manage_memory_empty_content(self, myai):
result = myai.manage_memory_tool("", action="append")
assert "error" in result.lower() or "Error" in result
# =========================================================================
# ask() with mock LLM tests
# =========================================================================
class TestAsk:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_ask_basic_response(self, myai, mock_litellm):
result = myai.ask("hello", stream=False)
assert "response" in result
assert "chat_history" in result
assert "usage" in result
assert result["response"] == "Test response from AI"
def test_ask_sticky_brain_engineer(self, myai, mock_litellm):
result = myai.ask("show me the routers", stream=False)
assert result["responder"] == "engineer"
def test_ask_explicit_architect(self, myai, mock_litellm):
result = myai.ask("architect: review the network design", stream=False)
assert result["responder"] == "architect"
def test_ask_returns_usage(self, myai, mock_litellm):
result = myai.ask("test", stream=False)
assert result["usage"]["total"] > 0
def test_ask_with_chat_history(self, myai, mock_litellm):
history = [
{"role": "user", "content": "previous question"},
{"role": "assistant", "content": "previous answer"}
]
result = myai.ask("follow up", chat_history=history, stream=False)
assert result["response"] is not None
# =========================================================================
# _get_engineer_tools / _get_architect_tools tests
# =========================================================================
class TestToolDefinitions:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_engineer_tools_include_core(self, myai):
tools = myai._get_engineer_tools()
names = [t["function"]["name"] for t in tools]
assert "list_nodes" in names
assert "run_commands" in names
assert "get_node_info" in names
assert "consult_architect" in names
assert "escalate_to_architect" in names
def test_engineer_tools_include_external(self, myai):
myai.external_engineer_tools.append({
"type": "function",
"function": {"name": "custom_tool", "description": "test", "parameters": {}}
})
tools = myai._get_engineer_tools()
names = [t["function"]["name"] for t in tools]
assert "custom_tool" in names
def test_architect_tools_include_core(self, myai):
tools = myai._get_architect_tools()
names = [t["function"]["name"] for t in tools]
assert "delegate_to_engineer" in names
assert "return_to_engineer" in names
assert "manage_memory_tool" in names
def test_architect_tools_include_external(self, myai):
myai.external_architect_tools.append({
"type": "function",
"function": {"name": "arch_tool", "description": "test", "parameters": {}}
})
tools = myai._get_architect_tools()
names = [t["function"]["name"] for t in tools]
assert "arch_tool" in names
+268
View File
@@ -0,0 +1,268 @@
"""Tests for connpy.api module — Flask routes."""
import json
import pytest
from unittest.mock import patch, MagicMock
@pytest.fixture
def api_client(populated_config):
"""Create a Flask test client with a populated config."""
from connpy.api import app
app.custom_config = populated_config
app.config["TESTING"] = True
with app.test_client() as client:
yield client
# =========================================================================
# Root endpoint
# =========================================================================
class TestRootEndpoint:
def test_root_returns_welcome(self, api_client):
response = api_client.get("/")
data = response.get_json()
assert response.status_code == 200
assert "Welcome" in data["message"]
assert "version" in data
# =========================================================================
# /list_nodes endpoint
# =========================================================================
class TestListNodes:
def test_list_nodes_no_filter(self, api_client):
response = api_client.post("/list_nodes", json={})
data = response.get_json()
assert response.status_code == 200
assert isinstance(data, list)
assert "router1" in data
def test_list_nodes_with_filter(self, api_client):
response = api_client.post("/list_nodes", json={"filter": "router.*"})
data = response.get_json()
assert "router1" in data
assert all("router" in n or "Router" in n for n in data)
def test_list_nodes_case_insensitive(self, api_client):
"""Filter is lowercased when case=false."""
response = api_client.post("/list_nodes", json={"filter": "ROUTER.*"})
data = response.get_json()
# Should still match since the filter gets lowercased
assert isinstance(data, list)
def test_list_nodes_no_body(self, api_client):
"""No body returns all nodes."""
response = api_client.post("/list_nodes",
data="",
content_type="application/json")
data = response.get_json()
assert isinstance(data, list)
# =========================================================================
# /get_nodes endpoint
# =========================================================================
class TestGetNodes:
def test_get_nodes_no_filter(self, api_client):
response = api_client.post("/get_nodes", json={})
data = response.get_json()
assert response.status_code == 200
assert isinstance(data, dict)
assert "router1" in data
def test_get_nodes_with_filter(self, api_client):
response = api_client.post("/get_nodes", json={"filter": "router.*"})
data = response.get_json()
assert "router1" in data
assert "host" in data["router1"]
def test_get_nodes_has_attributes(self, api_client):
response = api_client.post("/get_nodes", json={"filter": "router1"})
data = response.get_json()
if "router1" in data:
assert "host" in data["router1"]
assert "protocol" in data["router1"]
# =========================================================================
# /run_commands endpoint
# =========================================================================
class TestRunCommands:
def test_missing_action(self, api_client):
response = api_client.post("/run_commands", json={
"nodes": "router1",
"commands": ["show version"]
})
data = response.get_json()
assert "DataError" in data
assert "action" in data["DataError"]
def test_missing_nodes(self, api_client):
response = api_client.post("/run_commands", json={
"action": "run",
"commands": ["show version"]
})
data = response.get_json()
assert "DataError" in data
assert "nodes" in data["DataError"]
def test_missing_commands(self, api_client):
response = api_client.post("/run_commands", json={
"action": "run",
"nodes": "router1"
})
data = response.get_json()
assert "DataError" in data
assert "commands" in data["DataError"]
def test_wrong_action(self, api_client):
response = api_client.post("/run_commands", json={
"action": "invalid",
"nodes": "router1",
"commands": ["show version"]
})
data = response.get_json()
assert "DataError" in data
assert "Wrong action" in data["DataError"]
@patch("connpy.api.nodes")
def test_run_action(self, mock_nodes_cls, api_client):
"""action=run executes and returns output."""
mock_instance = MagicMock()
mock_instance.run.return_value = {"router1": "Router v1.0"}
mock_nodes_cls.return_value = mock_instance
response = api_client.post("/run_commands", json={
"action": "run",
"nodes": "router1",
"commands": ["show version"]
})
data = response.get_json()
assert "router1" in data
@patch("connpy.api.nodes")
def test_test_action(self, mock_nodes_cls, api_client):
"""action=test returns result + output."""
mock_instance = MagicMock()
mock_instance.test.return_value = {"router1": {"expected": True}}
mock_instance.output = {"router1": "output text"}
mock_nodes_cls.return_value = mock_instance
response = api_client.post("/run_commands", json={
"action": "test",
"nodes": "router1",
"commands": ["show version"],
"expected": "Router"
})
data = response.get_json()
assert "result" in data
assert "output" in data
@patch("connpy.api.nodes")
def test_run_with_options(self, mock_nodes_cls, api_client):
"""Options get passed through."""
mock_instance = MagicMock()
mock_instance.run.return_value = {"router1": "ok"}
mock_nodes_cls.return_value = mock_instance
response = api_client.post("/run_commands", json={
"action": "run",
"nodes": "router1",
"commands": ["show version"],
"options": {"timeout": 30, "parallel": 5}
})
assert response.status_code == 200
@patch("connpy.api.nodes")
def test_run_folder_nodes(self, mock_nodes_cls, api_client):
"""Nodes with @ prefix are resolved as folders."""
mock_instance = MagicMock()
mock_instance.run.return_value = {"server1@office": "ok"}
mock_nodes_cls.return_value = mock_instance
response = api_client.post("/run_commands", json={
"action": "run",
"nodes": "@office",
"commands": ["ls -la"]
})
assert response.status_code == 200
@patch("connpy.api.nodes")
def test_run_list_nodes(self, mock_nodes_cls, api_client):
"""List of nodes is resolved correctly."""
mock_instance = MagicMock()
mock_instance.run.return_value = {"router1": "ok", "server1@office": "ok"}
mock_nodes_cls.return_value = mock_instance
response = api_client.post("/run_commands", json={
"action": "run",
"nodes": ["router1", "server1@office"],
"commands": ["show version"]
})
assert response.status_code == 200
# =========================================================================
# /ask_ai endpoint
# =========================================================================
class TestAskAI:
@patch("connpy.api.myai")
def test_ask_ai(self, mock_ai_cls, api_client):
mock_instance = MagicMock()
mock_instance.ask.return_value = {"response": "AI says hello"}
mock_ai_cls.return_value = mock_instance
response = api_client.post("/ask_ai", json={
"input": "list my routers"
})
data = response.get_json()
assert data is not None
@patch("connpy.api.myai")
def test_ask_ai_with_dryrun(self, mock_ai_cls, api_client):
mock_instance = MagicMock()
mock_instance.ask.return_value = {"response": "dry run"}
mock_ai_cls.return_value = mock_instance
response = api_client.post("/ask_ai", json={
"input": "test",
"dryrun": True
})
assert response.status_code == 200
@patch("connpy.api.myai")
def test_ask_ai_with_history(self, mock_ai_cls, api_client):
mock_instance = MagicMock()
mock_instance.ask.return_value = {"response": "with history"}
mock_ai_cls.return_value = mock_instance
response = api_client.post("/ask_ai", json={
"input": "follow up",
"chat_history": [
{"role": "user", "content": "previous"},
{"role": "assistant", "content": "answer"}
]
})
assert response.status_code == 200
# =========================================================================
# /confirm endpoint
# =========================================================================
class TestConfirm:
@patch("connpy.api.myai")
def test_confirm(self, mock_ai_cls, api_client):
mock_instance = MagicMock()
mock_instance.confirm.return_value = True
mock_ai_cls.return_value = mock_instance
response = api_client.post("/confirm", json={
"input": "yes"
})
assert response.status_code == 200
+182
View File
@@ -0,0 +1,182 @@
"""Tests for connpy.completion module."""
import os
import json
import pytest
from connpy.completion import _getallnodes, _getallfolders, _getcwd, _get_plugins
# =========================================================================
# _getallnodes tests
# =========================================================================
class TestGetAllNodes:
def test_flat_nodes(self):
"""Nodes without folders."""
config = {
"connections": {
"router1": {"type": "connection"},
"router2": {"type": "connection"}
}
}
nodes = _getallnodes(config)
assert "router1" in nodes
assert "router2" in nodes
def test_nested_nodes(self):
"""Nodes in folders and subfolders have correct format."""
config = {
"connections": {
"router1": {"type": "connection"},
"office": {
"type": "folder",
"server1": {"type": "connection"},
"datacenter": {
"type": "subfolder",
"db1": {"type": "connection"}
}
}
}
}
nodes = _getallnodes(config)
assert "router1" in nodes
assert "server1@office" in nodes
assert "db1@datacenter@office" in nodes
def test_empty_connections(self):
config = {"connections": {}}
nodes = _getallnodes(config)
assert nodes == []
# =========================================================================
# _getallfolders tests
# =========================================================================
class TestGetAllFolders:
def test_basic_folders(self):
config = {
"connections": {
"office": {"type": "folder"},
"home": {"type": "folder"}
}
}
folders = _getallfolders(config)
assert "@office" in folders
assert "@home" in folders
def test_with_subfolders(self):
config = {
"connections": {
"office": {
"type": "folder",
"datacenter": {"type": "subfolder"},
"server1": {"type": "connection"}
}
}
}
folders = _getallfolders(config)
assert "@office" in folders
assert "@datacenter@office" in folders
def test_empty(self):
config = {"connections": {}}
folders = _getallfolders(config)
assert folders == []
# =========================================================================
# _getcwd tests
# =========================================================================
class TestGetCwd:
def test_current_dir(self, tmp_path, monkeypatch):
"""Lists files in current directory."""
monkeypatch.chdir(tmp_path)
(tmp_path / "file1.txt").touch()
(tmp_path / "file2.py").touch()
subdir = tmp_path / "subdir"
subdir.mkdir()
result = _getcwd(["run", "run"], "run")
# Should list files
assert any("file1.txt" in r for r in result)
assert any("subdir/" in r for r in result)
def test_specific_path(self, tmp_path, monkeypatch):
"""Lists files matching a partial path."""
monkeypatch.chdir(tmp_path)
(tmp_path / "script.yaml").touch()
(tmp_path / "script2.yaml").touch()
result = _getcwd(["run", "script"], "run")
assert any("script" in r for r in result)
def test_folder_only(self, tmp_path, monkeypatch):
"""folderonly=True returns only directories."""
monkeypatch.chdir(tmp_path)
(tmp_path / "file.txt").touch()
subdir = tmp_path / "mydir"
subdir.mkdir()
result = _getcwd(["export", "export"], "export", folderonly=True)
files_in_result = [r for r in result if "file.txt" in r]
assert len(files_in_result) == 0
dirs_in_result = [r for r in result if "mydir" in r]
assert len(dirs_in_result) > 0
# =========================================================================
# _get_plugins tests
# =========================================================================
class TestGetPlugins:
def test_get_plugins_disable(self, tmp_path):
"""--disable returns enabled plugins."""
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
(plugin_dir / "active.py").touch()
(plugin_dir / "disabled.py.bkp").touch()
result = _get_plugins("--disable", str(tmp_path))
assert "active" in result
assert "disabled" not in result
def test_get_plugins_enable(self, tmp_path):
"""--enable returns disabled plugins."""
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
(plugin_dir / "active.py").touch()
(plugin_dir / "disabled.py.bkp").touch()
result = _get_plugins("--enable", str(tmp_path))
assert "disabled" in result
assert "active" not in result
def test_get_plugins_del(self, tmp_path):
"""--del returns all plugins."""
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
(plugin_dir / "active.py").touch()
(plugin_dir / "disabled.py.bkp").touch()
result = _get_plugins("--del", str(tmp_path))
assert "active" in result
assert "disabled" in result
def test_get_plugins_all(self, tmp_path):
"""'all' returns dict with paths."""
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
(plugin_dir / "myplugin.py").touch()
result = _get_plugins("all", str(tmp_path))
assert isinstance(result, dict)
assert "myplugin" in result
def test_get_plugins_empty_dir(self, tmp_path):
"""Empty plugins directory returns empty list."""
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
result = _get_plugins("--disable", str(tmp_path))
assert result == []
+376
View File
@@ -0,0 +1,376 @@
"""Tests for connpy.configfile module."""
import json
import os
import re
import pytest
from copy import deepcopy
class TestConfigfileInit:
def test_creates_default_config(self, tmp_config_dir):
"""Creates config.json with defaults when it doesn't exist."""
config_file = tmp_config_dir / "config.json"
config_file.unlink() # Remove existing
key_file = tmp_config_dir / ".osk"
from connpy.configfile import configfile
conf = configfile(conf=str(config_file), key=str(key_file))
assert config_file.exists()
assert conf.config["case"] == False
assert conf.config["idletime"] == 30
assert "default" in conf.profiles
def test_creates_rsa_key(self, tmp_config_dir):
"""Generates RSA key when it doesn't exist."""
key_file = tmp_config_dir / ".osk"
key_file.unlink() # Remove existing
from connpy.configfile import configfile
conf = configfile(conf=str(tmp_config_dir / "config.json"), key=str(key_file))
assert key_file.exists()
assert conf.privatekey is not None
assert conf.publickey is not None
def test_loads_existing_config(self, config):
"""Loads correctly from existing config."""
assert config.config is not None
assert config.connections is not None
assert config.profiles is not None
def test_config_file_permissions(self, tmp_config_dir):
"""Config is created with 0o600 permissions."""
config_file = tmp_config_dir / "config.json"
config_file.unlink()
from connpy.configfile import configfile
configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
stat = os.stat(str(config_file))
assert oct(stat.st_mode & 0o777) == oct(0o600)
def test_custom_paths(self, tmp_path):
"""Accepts custom paths for conf and key."""
config_dir = tmp_path / "custom"
config_dir.mkdir()
(config_dir / "plugins").mkdir()
# Write .folder for the config dir
dot_folder = tmp_path / ".config" / "conn"
dot_folder.mkdir(parents=True, exist_ok=True)
(dot_folder / ".folder").write_text(str(config_dir))
(dot_folder / "plugins").mkdir(exist_ok=True)
conf_path = str(config_dir / "my_config.json")
key_path = str(config_dir / "my_key")
from connpy.configfile import configfile
conf = configfile(conf=conf_path, key=key_path)
assert conf.file == conf_path
assert conf.key == key_path
class TestEncryption:
def test_encrypt_password(self, config):
"""Encrypts and produces b'...' format."""
encrypted = config.encrypt("mysecret")
assert encrypted.startswith("b'") or encrypted.startswith('b"')
def test_encrypt_decrypt_roundtrip(self, config):
"""Encrypt then decrypt returns original."""
from Crypto.PublicKey import RSA
from Crypto.Cipher import PKCS1_OAEP
import ast
original = "super_secret_password"
encrypted = config.encrypt(original)
# Decrypt
with open(config.key) as f:
key = RSA.import_key(f.read())
decryptor = PKCS1_OAEP.new(key)
decrypted = decryptor.decrypt(ast.literal_eval(encrypted)).decode("utf-8")
assert decrypted == original
class TestExplodeUnique:
def test_simple_node(self, config):
result = config._explode_unique("router1")
assert result == {"id": "router1"}
def test_node_with_folder(self, config):
result = config._explode_unique("r1@office")
assert result == {"id": "r1", "folder": "office"}
def test_node_with_subfolder(self, config):
result = config._explode_unique("r1@dc@office")
assert result == {"id": "r1", "folder": "office", "subfolder": "dc"}
def test_folder_only(self, config):
result = config._explode_unique("@office")
assert result == {"folder": "office"}
def test_subfolder_only(self, config):
result = config._explode_unique("@dc@office")
assert result == {"folder": "office", "subfolder": "dc"}
def test_too_deep(self, config):
result = config._explode_unique("a@b@c@d")
assert result == False
def test_empty_folder(self, config):
result = config._explode_unique("a@")
assert result == False
def test_empty_subfolder(self, config):
result = config._explode_unique("a@@office")
assert result == False
class TestCRUDNodes:
def test_add_node_root(self, config):
config._connections_add(
id="router1", host="10.0.0.1", protocol="ssh",
port="22", user="admin", password="pass", options="",
logs="", tags="", jumphost=""
)
assert "router1" in config.connections
assert config.connections["router1"]["host"] == "10.0.0.1"
def test_add_node_folder(self, config):
config._folder_add(folder="office")
config._connections_add(
id="server1", folder="office", host="10.0.1.1",
protocol="ssh", port="", user="root", password="pass",
options="", logs="", tags="", jumphost=""
)
assert "server1" in config.connections["office"]
def test_add_node_subfolder(self, config):
config._folder_add(folder="office")
config._folder_add(folder="office", subfolder="dc")
config._connections_add(
id="db1", folder="office", subfolder="dc", host="10.0.2.1",
protocol="ssh", port="", user="dbadmin", password="pass",
options="", logs="", tags="", jumphost=""
)
assert "db1" in config.connections["office"]["dc"]
def test_del_node_root(self, config):
config._connections_add(
id="router1", host="10.0.0.1", protocol="ssh",
port="", user="", password="", options="",
logs="", tags="", jumphost=""
)
config._connections_del(id="router1")
assert "router1" not in config.connections
def test_del_node_folder(self, config):
config._folder_add(folder="office")
config._connections_add(
id="server1", folder="office", host="10.0.1.1",
protocol="ssh", port="", user="", password="",
options="", logs="", tags="", jumphost=""
)
config._connections_del(id="server1", folder="office")
assert "server1" not in config.connections["office"]
def test_add_folder(self, config):
config._folder_add(folder="office")
assert "office" in config.connections
assert config.connections["office"]["type"] == "folder"
def test_add_subfolder(self, config):
config._folder_add(folder="office")
config._folder_add(folder="office", subfolder="dc")
assert "dc" in config.connections["office"]
assert config.connections["office"]["dc"]["type"] == "subfolder"
def test_del_folder(self, config):
config._folder_add(folder="office")
config._folder_del(folder="office")
assert "office" not in config.connections
def test_del_subfolder(self, config):
config._folder_add(folder="office")
config._folder_add(folder="office", subfolder="dc")
config._folder_del(folder="office", subfolder="dc")
assert "dc" not in config.connections["office"]
class TestCRUDProfiles:
def test_add_profile(self, config):
config._profiles_add(
id="myprofile", host="", protocol="telnet",
port="23", user="user1", password="pass1",
options="", logs="", tags="", jumphost=""
)
assert "myprofile" in config.profiles
assert config.profiles["myprofile"]["protocol"] == "telnet"
def test_del_profile(self, config):
config._profiles_add(
id="temp", host="", protocol="ssh", port="",
user="", password="", options="", logs="", tags="", jumphost=""
)
config._profiles_del(id="temp")
assert "temp" not in config.profiles
def test_default_profile_exists(self, config):
assert "default" in config.profiles
class TestGetItem:
def test_getitem_node(self, populated_config):
node = populated_config.getitem("router1")
assert node["host"] == "10.0.0.1"
assert "type" not in node # type is stripped
def test_getitem_folder(self, populated_config):
nodes = populated_config.getitem("@office")
# Should contain server1@office but NOT datacenter (subfolder)
assert "server1@office" in nodes
assert all("type" not in v for v in nodes.values())
def test_getitem_subfolder(self, populated_config):
nodes = populated_config.getitem("@datacenter@office")
assert "db1@datacenter@office" in nodes
def test_getitem_node_in_folder(self, populated_config):
node = populated_config.getitem("server1@office")
assert node["host"] == "10.0.1.1"
def test_getitem_node_in_subfolder(self, populated_config):
node = populated_config.getitem("db1@datacenter@office")
assert node["host"] == "10.0.2.1"
def test_getitem_with_profile_extraction(self, tmp_config_dir):
"""extract=True resolves @profile references."""
config_file = tmp_config_dir / "config.json"
data = {
"config": {"case": False, "idletime": 30, "fzf": False},
"connections": {
"router1": {
"host": "10.0.0.1", "protocol": "ssh", "port": "",
"user": "@office-user", "password": "@office-user",
"options": "", "logs": "", "tags": "", "jumphost": "",
"type": "connection"
}
},
"profiles": {
"default": {"host": "", "protocol": "ssh", "port": "",
"user": "", "password": "", "options": "",
"logs": "", "tags": "", "jumphost": ""},
"office-user": {"host": "", "protocol": "ssh", "port": "",
"user": "officeadmin", "password": "officepass",
"options": "", "logs": "", "tags": "", "jumphost": ""}
}
}
config_file.write_text(json.dumps(data, indent=4))
from connpy.configfile import configfile
conf = configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
node = conf.getitem("router1", extract=True)
assert node["user"] == "officeadmin"
assert node["password"] == "officepass"
def test_getitems_multiple(self, populated_config):
nodes = populated_config.getitems(["router1", "server1@office"])
assert "router1" in nodes
assert "server1@office" in nodes
def test_getitems_folder(self, populated_config):
nodes = populated_config.getitems(["@office"])
assert "server1@office" in nodes
class TestGetAll:
def test_getallnodes_no_filter(self, populated_config):
nodes = populated_config._getallnodes()
assert "router1" in nodes
assert "server1@office" in nodes
assert "db1@datacenter@office" in nodes
def test_getallnodes_string_filter(self, populated_config):
nodes = populated_config._getallnodes("router.*")
assert "router1" in nodes
assert "server1@office" not in nodes
def test_getallnodes_list_filter(self, populated_config):
nodes = populated_config._getallnodes(["router.*", "db.*"])
assert "router1" in nodes
assert "db1@datacenter@office" in nodes
assert "server1@office" not in nodes
def test_getallnodes_filter_invalid_type(self, populated_config):
with pytest.raises(ValueError):
populated_config._getallnodes(123)
def test_getallfolders(self, populated_config):
folders = populated_config._getallfolders()
assert "@office" in folders
assert "@datacenter@office" in folders
def test_getallnodesfull(self, populated_config):
nodes = populated_config._getallnodesfull()
assert "router1" in nodes
assert nodes["router1"]["host"] == "10.0.0.1"
def test_getallnodesfull_with_filter(self, populated_config):
nodes = populated_config._getallnodesfull("router.*")
assert "router1" in nodes
assert "server1@office" not in nodes
def test_profileused(self, tmp_config_dir):
"""Detects nodes using a specific profile."""
config_file = tmp_config_dir / "config.json"
data = {
"config": {"case": False, "idletime": 30, "fzf": False},
"connections": {
"router1": {
"host": "10.0.0.1", "protocol": "ssh", "port": "",
"user": "@myprofile", "password": "pass",
"options": "", "logs": "", "tags": "", "jumphost": "",
"type": "connection"
},
"router2": {
"host": "10.0.0.2", "protocol": "ssh", "port": "",
"user": "admin", "password": "pass",
"options": "", "logs": "", "tags": "", "jumphost": "",
"type": "connection"
}
},
"profiles": {
"default": {"host": "", "protocol": "ssh", "port": "",
"user": "", "password": "", "options": "",
"logs": "", "tags": "", "jumphost": ""},
"myprofile": {"host": "", "protocol": "ssh", "port": "",
"user": "profuser", "password": "profpass",
"options": "", "logs": "", "tags": "", "jumphost": ""}
}
}
config_file.write_text(json.dumps(data, indent=4))
from connpy.configfile import configfile
conf = configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
used = conf._profileused("myprofile")
assert "router1" in used
assert "router2" not in used
def test_saveconfig(self, config):
"""Save and reload correctly."""
config._connections_add(
id="test_node", host="1.2.3.4", protocol="ssh",
port="", user="", password="", options="",
logs="", tags="", jumphost=""
)
result = config._saveconfig(config.file)
assert result == 0
# Reload and verify
from connpy.configfile import configfile
reloaded = configfile(conf=config.file, key=config.key)
assert "test_node" in reloaded.connections
+419
View File
@@ -0,0 +1,419 @@
"""Tests for connpy.core module — node and nodes classes."""
import json
import os
import io
import re
import pytest
from unittest.mock import patch, MagicMock, PropertyMock
from copy import deepcopy
# =========================================================================
# node.__init__ tests
# =========================================================================
class TestNodeInit:
def test_basic_init(self):
"""Creates node with basic attributes."""
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="pass1", protocol="ssh")
assert n.unique == "router1"
assert n.host == "10.0.0.1"
assert n.user == "admin"
assert n.protocol == "ssh"
assert n.password == ["pass1"]
def test_default_protocol(self):
"""Default protocol is ssh."""
from connpy.core import node
n = node("router1", "10.0.0.1")
assert n.protocol == "ssh"
def test_password_as_list_of_profiles(self, populated_config):
"""Password list with @profile references resolves correctly."""
from connpy.core import node
n = node("router1", "10.0.0.1", password=["@office-user"],
config=populated_config)
assert n.password == ["officepass"]
def test_password_plain_string(self):
"""Plain string password is wrapped in a list."""
from connpy.core import node
n = node("router1", "10.0.0.1", password="mypass")
assert n.password == ["mypass"]
def test_node_with_profile(self, populated_config):
"""Resolves @profile references for user."""
from connpy.core import node
n = node("test1", "10.0.0.1", user="@office-user", password="plain",
config=populated_config)
assert n.user == "officeadmin"
def test_node_tags(self):
"""Tags are stored correctly."""
from connpy.core import node
tags = {"os": "cisco_ios", "prompt": r"Router#"}
n = node("router1", "10.0.0.1", tags=tags)
assert n.tags["os"] == "cisco_ios"
# =========================================================================
# Command generation tests
# =========================================================================
class TestCommandGeneration:
def _make_node(self, **kwargs):
from connpy.core import node
defaults = {
"unique": "test", "host": "10.0.0.1", "protocol": "ssh",
"user": "admin", "password": "", "port": "", "options": "",
"jumphost": "", "tags": "", "logs": ""
}
defaults.update(kwargs)
return node(defaults.pop("unique"), defaults.pop("host"), **defaults)
def test_ssh_cmd_basic(self):
n = self._make_node()
cmd = n._get_cmd()
assert "ssh" in cmd
assert "admin@10.0.0.1" in cmd
def test_ssh_cmd_port(self):
n = self._make_node(port="2222")
cmd = n._get_cmd()
assert "-p 2222" in cmd
def test_ssh_cmd_options(self):
n = self._make_node(options="-o StrictHostKeyChecking=no")
cmd = n._get_cmd()
assert "-o StrictHostKeyChecking=no" in cmd
def test_sftp_cmd_port(self):
n = self._make_node(protocol="sftp", port="2222")
cmd = n._get_cmd()
assert "-P 2222" in cmd # SFTP uses uppercase P
def test_telnet_cmd(self):
n = self._make_node(protocol="telnet", port="23")
cmd = n._get_cmd()
assert "telnet 10.0.0.1" in cmd
assert "23" in cmd
def test_kubectl_cmd(self):
n = self._make_node(protocol="kubectl", host="my-pod", tags={"kube_command": "/bin/sh"})
cmd = n._get_cmd()
assert "kubectl exec" in cmd
assert "my-pod" in cmd
assert "/bin/sh" in cmd
def test_kubectl_cmd_default_command(self):
n = self._make_node(protocol="kubectl", host="my-pod")
cmd = n._get_cmd()
assert "/bin/bash" in cmd
def test_docker_cmd(self):
n = self._make_node(protocol="docker", host="my-container",
tags={"docker_command": "/bin/sh"})
cmd = n._get_cmd()
assert "docker" in cmd
assert "my-container" in cmd
assert "/bin/sh" in cmd
def test_invalid_protocol_raises(self):
n = self._make_node(protocol="invalid_proto")
with pytest.raises(ValueError, match="Invalid protocol"):
n._get_cmd()
def test_ssh_cmd_no_user(self):
n = self._make_node(user="")
cmd = n._get_cmd()
assert "10.0.0.1" in cmd
assert "@" not in cmd # No user@ prefix
# =========================================================================
# Password decryption tests
# =========================================================================
class TestPasswordDecryption:
def test_passtx_plaintext(self, config):
"""Plaintext passwords pass through unchanged."""
from connpy.core import node
n = node("test", "10.0.0.1", password="plainpass", config=config)
result = n._passtx(["plainpass"])
assert result == ["plainpass"]
def test_passtx_encrypted(self, config):
"""Encrypted passwords get decrypted."""
from connpy.core import node
encrypted = config.encrypt("mysecret")
n = node("test", "10.0.0.1", password=encrypted, config=config)
result = n._passtx([encrypted])
assert result == ["mysecret"]
def test_passtx_missing_key_raises(self):
"""Missing key file raises ValueError."""
from connpy.core import node
n = node("test", "10.0.0.1", password="pass")
# A password formatted as encrypted but no valid key
with pytest.raises((ValueError, Exception)):
n._passtx(["""b'corrupted_encrypted_data'"""], keyfile="/nonexistent")
# =========================================================================
# Log handling tests
# =========================================================================
class TestLogHandling:
def test_logfile_variable_substitution(self):
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", protocol="ssh", port="22",
logs="/logs/${unique}_${host}_${user}")
result = n._logfile()
assert result == "/logs/router1_10.0.0.1_admin"
def test_logfile_date_substitution(self):
from connpy.core import node
import datetime
n = node("router1", "10.0.0.1", logs="/logs/${date '%Y'}")
result = n._logfile()
assert datetime.datetime.now().strftime("%Y") in result
def test_logclean_removes_ansi(self):
from connpy.core import node
n = node("test", "10.0.0.1")
dirty = "\x1B[32mgreen text\x1B[0m"
clean = n._logclean(dirty, var=True)
assert "\x1B" not in clean
assert "green text" in clean
def test_logclean_removes_backspaces(self):
from connpy.core import node
n = node("test", "10.0.0.1")
dirty = "type\bo"
clean = n._logclean(dirty, var=True)
assert "\b" not in clean
# =========================================================================
# run() and test() with mock pexpect
# =========================================================================
class TestNodeRun:
def _make_connected_node(self, mock_pexpect_obj, **kwargs):
"""Create a node and mock its _connect to succeed."""
from connpy.core import node
defaults = {
"unique": "router1", "host": "10.0.0.1",
"protocol": "ssh", "user": "admin", "password": ""
}
defaults.update(kwargs)
n = node(defaults.pop("unique"), defaults.pop("host"), **defaults)
return n
def test_run_returns_output(self, mock_pexpect):
"""run() returns string output."""
child = mock_pexpect["child"]
pexp = mock_pexpect["pexpect"]
# Simulate: connect succeeds, command runs, prompt found
child.expect.return_value = 9 # prompt index for ssh
child.logfile_read = None
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
# Mock _connect to return True and set up child
with patch.object(n, '_connect', return_value=True):
n.child = child
log_buffer = io.BytesIO(b"show version\nRouter v1.0\nrouter#")
n.mylog = log_buffer
child.logfile_read = log_buffer
with patch.object(n, '_logclean', return_value="Router v1.0"):
output = n.run(["show version"])
assert n.status == 0
assert output == "Router v1.0"
def test_run_status_1_on_failure(self, mock_pexpect):
"""Status 1 when connection fails."""
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
with patch.object(n, '_connect', return_value="Connection failed code: 1\nrefused"):
output = n.run(["show version"])
assert n.status == 1
assert "refused" in output
def test_run_with_variables(self, mock_pexpect):
"""Variables get substituted in commands."""
child = mock_pexpect["child"]
child.expect.return_value = 9
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
sent_commands = []
child.sendline.side_effect = lambda cmd: sent_commands.append(cmd)
with patch.object(n, '_connect', return_value=True):
n.child = child
n.mylog = io.BytesIO(b"output")
with patch.object(n, '_logclean', return_value="output"):
n.run(["show ip route {subnet}"], vars={"subnet": "10.0.0.0/24"})
assert "show ip route 10.0.0.0/24" in sent_commands
def test_run_saves_to_folder(self, mock_pexpect, tmp_path):
"""folder param saves log file."""
child = mock_pexpect["child"]
child.expect.return_value = 9
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
with patch.object(n, '_connect', return_value=True):
n.child = child
n.mylog = io.BytesIO(b"log output")
with patch.object(n, '_logclean', return_value="log output"):
n.run(["show version"], folder=str(tmp_path))
log_files = list(tmp_path.glob("router1_*.txt"))
assert len(log_files) == 1
assert "log output" in log_files[0].read_text()
class TestNodeTest:
def test_test_returns_dict(self, mock_pexpect):
"""test() returns dict of results."""
child = mock_pexpect["child"]
child.expect.return_value = 0 # prompt found (index 0 in test expects)
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
with patch.object(n, '_connect', return_value=True):
n.child = child
n.mylog = io.BytesIO(b"1.1.1.1 is up")
with patch.object(n, '_logclean', return_value="1.1.1.1 is up"):
result = n.test(["ping 1.1.1.1"], "1.1.1.1")
assert isinstance(result, dict)
assert result.get("1.1.1.1") == True
def test_test_expected_not_found(self, mock_pexpect):
"""Expected text not found returns False."""
child = mock_pexpect["child"]
child.expect.return_value = 0
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
with patch.object(n, '_connect', return_value=True):
n.child = child
n.mylog = io.BytesIO(b"some other output")
with patch.object(n, '_logclean', return_value="some other output"):
result = n.test(["ping 1.1.1.1"], "1.1.1.1")
assert isinstance(result, dict)
assert result.get("1.1.1.1") == False
# =========================================================================
# nodes (parallel) tests
# =========================================================================
class TestNodes:
def test_nodes_init(self):
"""Creates list of node objects."""
from connpy.core import nodes
nodes_dict = {
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
"r2": {"host": "10.0.0.2", "user": "admin", "password": ""}
}
mynodes = nodes(nodes_dict)
assert len(mynodes.nodelist) == 2
assert hasattr(mynodes, "r1")
assert hasattr(mynodes, "r2")
def test_nodes_run_parallel(self):
"""run() executes on all nodes and returns dict."""
from connpy.core import nodes
nodes_dict = {
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
"r2": {"host": "10.0.0.2", "user": "admin", "password": ""}
}
mynodes = nodes(nodes_dict)
# Mock run on each node — must set output AND status on the node
for n in mynodes.nodelist:
original_node = n # capture by value
def make_mock(node_ref):
def mock_run(commands, **kwargs):
node_ref.output = f"output from {node_ref.unique}"
node_ref.status = 0
return mock_run
n.run = make_mock(n)
result = mynodes.run(["show version"])
assert "r1" in result
assert "r2" in result
def test_nodes_splitlist(self):
"""_splitlist divides list correctly."""
from connpy.core import nodes
mynodes = nodes({"r1": {"host": "1.1.1.1", "user": "", "password": ""}})
chunks = list(mynodes._splitlist([1, 2, 3, 4, 5], 2))
assert chunks == [[1, 2], [3, 4], [5]]
def test_nodes_run_with_vars(self):
"""Variables per node and __global__ work."""
from connpy.core import nodes
nodes_dict = {
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
}
mynodes = nodes(nodes_dict)
captured_vars = {}
def mock_run(commands, vars=None, **kwargs):
captured_vars.update(vars or {})
mynodes.r1.output = "ok"
mynodes.r1.status = 0
mynodes.r1.run = mock_run
variables = {
"__global__": {"mask": "255.255.255.0"},
"r1": {"ip": "10.0.0.1"}
}
mynodes.run(["show ip"], vars=variables)
assert captured_vars.get("mask") == "255.255.255.0"
assert captured_vars.get("ip") == "10.0.0.1"
def test_nodes_on_complete_callback(self):
"""on_complete callback fires per node."""
from connpy.core import nodes
nodes_dict = {
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
}
mynodes = nodes(nodes_dict)
completed = []
def mock_run(commands, **kwargs):
mynodes.r1.output = "done"
mynodes.r1.status = 0
mynodes.r1.run = mock_run
def on_done(unique, output, status):
completed.append(unique)
mynodes.run(["show version"], on_complete=on_done)
assert "r1" in completed
+216
View File
@@ -0,0 +1,216 @@
"""Tests for connpy.hooks module — MethodHook and ClassHook."""
import pytest
from connpy.hooks import MethodHook, ClassHook
# =========================================================================
# MethodHook Tests
# =========================================================================
class TestMethodHook:
def test_basic_call(self):
"""Decorated function executes normally."""
@MethodHook
def add(a, b):
return a + b
assert add(2, 3) == 5
def test_pre_hook_modifies_args(self):
"""Pre-hook can modify arguments before execution."""
@MethodHook
def greet(name):
return f"Hello {name}"
def uppercase_hook(name):
return (name.upper(),), {}
greet.register_pre_hook(uppercase_hook)
assert greet("world") == "Hello WORLD"
def test_post_hook_modifies_result(self):
"""Post-hook can modify the return value."""
@MethodHook
def compute(x):
return x * 2
def double_result(*args, **kwargs):
return kwargs["result"] * 2
compute.register_post_hook(double_result)
assert compute(5) == 20 # 5*2=10, then 10*2=20
def test_multiple_pre_hooks_order(self):
"""Pre-hooks execute in registration order."""
calls = []
@MethodHook
def func(x):
return x
def hook1(x):
calls.append("hook1")
return (x,), {}
def hook2(x):
calls.append("hook2")
return (x,), {}
func.register_pre_hook(hook1)
func.register_pre_hook(hook2)
func(1)
assert calls == ["hook1", "hook2"]
def test_multiple_post_hooks_order(self):
"""Post-hooks execute in registration order."""
calls = []
@MethodHook
def func(x):
return x
def hook1(*args, **kwargs):
calls.append("hook1")
return kwargs["result"]
def hook2(*args, **kwargs):
calls.append("hook2")
return kwargs["result"]
func.register_post_hook(hook1)
func.register_post_hook(hook2)
func(1)
assert calls == ["hook1", "hook2"]
def test_pre_hook_exception_continues(self, capsys):
"""If a pre-hook raises, the function still executes."""
@MethodHook
def func(x):
return x + 1
def bad_hook(x):
raise RuntimeError("broken hook")
func.register_pre_hook(bad_hook)
# Should not raise — the hook error is printed but execution continues
result = func(5)
assert result == 6
def test_post_hook_exception_continues(self, capsys):
"""If a post-hook raises, the result is still returned."""
@MethodHook
def func(x):
return x + 1
def bad_hook(*args, **kwargs):
raise RuntimeError("broken post hook")
func.register_post_hook(bad_hook)
result = func(5)
assert result == 6
def test_method_hook_as_instance_method(self):
"""MethodHook works as a descriptor on a class."""
class MyClass:
@MethodHook
def double(self, x):
return x * 2
obj = MyClass()
assert obj.double(5) == 10
def test_method_hook_instance_hook_registration(self):
"""Can register hooks via instance method access."""
class MyClass:
@MethodHook
def process(self, x):
return x
def add_ten(*args, **kwargs):
return kwargs["result"] + 10
obj = MyClass()
obj.process.register_post_hook(add_ten)
assert obj.process(5) == 15
# =========================================================================
# ClassHook Tests
# =========================================================================
class TestClassHook:
def test_creates_instance(self):
"""ClassHook still creates instances normally."""
@ClassHook
class MyClass:
def __init__(self, value):
self.value = value
obj = MyClass(42)
assert obj.value == 42
def test_modify_future_instances(self):
"""modify() affects all future instances."""
@ClassHook
class MyClass:
def __init__(self):
self.x = 1
def set_x_to_99(instance):
instance.x = 99
MyClass.modify(set_x_to_99)
obj = MyClass()
assert obj.x == 99
def test_modify_does_not_affect_past(self):
"""modify() does not affect already-created instances."""
@ClassHook
class MyClass:
def __init__(self):
self.x = 1
old_obj = MyClass()
def set_x_to_99(instance):
instance.x = 99
MyClass.modify(set_x_to_99)
assert old_obj.x == 1 # Not affected
assert MyClass().x == 99 # New instance IS affected
def test_instance_modify(self):
"""instance.modify() only affects that specific instance."""
@ClassHook
class MyClass:
def __init__(self):
self.x = 1
obj1 = MyClass()
obj2 = MyClass()
obj1.modify(lambda inst: setattr(inst, 'x', 999))
assert obj1.x == 999
assert obj2.x == 1
def test_multiple_deferred_hooks(self):
"""Multiple modify() calls apply in order."""
@ClassHook
class MyClass:
def __init__(self):
self.log = []
MyClass.modify(lambda inst: inst.log.append("first"))
MyClass.modify(lambda inst: inst.log.append("second"))
obj = MyClass()
assert obj.log == ["first", "second"]
def test_getattr_delegation(self):
"""ClassHook delegates attribute access to the wrapped class."""
@ClassHook
class MyClass:
class_var = "hello"
def __init__(self):
pass
assert MyClass.class_var == "hello"
+327
View File
@@ -0,0 +1,327 @@
"""Tests for connpy.plugins module."""
import os
import textwrap
import pytest
from connpy.plugins import Plugins
# ---------------------------------------------------------------------------
# Helper: write a plugin script to a file
# ---------------------------------------------------------------------------
def _write_plugin(path, code):
"""Write dedented code to a file."""
with open(path, "w") as f:
f.write(textwrap.dedent(code))
# =========================================================================
# verify_script tests
# =========================================================================
class TestVerifyScript:
def test_valid_parser_entrypoint(self, tmp_path):
p = tmp_path / "good.py"
_write_plugin(p, """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser()
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
""")
plugins = Plugins()
assert plugins.verify_script(str(p)) == False
def test_valid_preload_only(self, tmp_path):
p = tmp_path / "preload.py"
_write_plugin(p, """\
class Preload:
def __init__(self, connapp):
pass
""")
plugins = Plugins()
assert plugins.verify_script(str(p)) == False
def test_valid_all_three(self, tmp_path):
p = tmp_path / "all.py"
_write_plugin(p, """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser()
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
class Preload:
def __init__(self, connapp):
pass
""")
plugins = Plugins()
assert plugins.verify_script(str(p)) == False
def test_parser_without_entrypoint(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser()
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result # Should be a truthy error string
assert "Entrypoint" in result
def test_entrypoint_without_parser(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "Parser" in result
def test_no_valid_class(self, tmp_path):
p = tmp_path / "empty.py"
_write_plugin(p, """\
def some_function():
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "No valid class" in result
def test_parser_missing_self_parser(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
class Parser:
def __init__(self):
self.something = "not parser"
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "self.parser" in result
def test_entrypoint_wrong_args(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser()
class Entrypoint:
def __init__(self, args):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "Entrypoint" in result
def test_preload_wrong_args(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
class Preload:
def __init__(self, connapp, extra):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "Preload" in result
def test_disallowed_top_level(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
MY_GLOBAL = "not allowed"
class Preload:
def __init__(self, connapp):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "not allowed" in result.lower() or "Plugin can only have" in result
def test_syntax_error(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
def broken(
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "Syntax error" in result
def test_if_name_main_allowed(self, tmp_path):
p = tmp_path / "good.py"
_write_plugin(p, """\
class Preload:
def __init__(self, connapp):
pass
if __name__ == "__main__":
print("standalone")
""")
plugins = Plugins()
assert plugins.verify_script(str(p)) == False
def test_other_if_not_allowed(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
import sys
if sys.platform == "linux":
pass
class Preload:
def __init__(self, connapp):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "__name__" in result
# =========================================================================
# Import and loading tests
# =========================================================================
class TestPluginLoading:
def test_import_from_path(self, tmp_path):
p = tmp_path / "mymod.py"
_write_plugin(p, """\
MY_VAR = 42
""")
plugins = Plugins()
module = plugins._import_from_path(str(p))
assert module.MY_VAR == 42
def test_import_plugins_to_argparse(self, tmp_path):
"""Valid plugins get loaded into argparse."""
import argparse
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
_write_plugin(plugin_dir / "myplugin.py", """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser(description="My plugin")
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
""")
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
plugins = Plugins()
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
assert "myplugin" in plugins.plugins
assert "myplugin" in plugins.plugin_parsers
def test_plugin_name_collision(self, tmp_path):
"""Plugin with same name as existing subcommand is skipped."""
import argparse
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
_write_plugin(plugin_dir / "existcmd.py", """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser()
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
""")
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
subparsers.add_parser("existcmd") # Already taken
plugins = Plugins()
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
assert "existcmd" not in plugins.plugins
def test_preload_registration(self, tmp_path):
"""Preload class gets registered in preloads dict."""
import argparse
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
_write_plugin(plugin_dir / "preloader.py", """\
class Preload:
def __init__(self, connapp):
pass
""")
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
plugins = Plugins()
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
assert "preloader" in plugins.preloads
def test_invalid_plugin_skipped(self, tmp_path, capsys):
"""Invalid plugin is skipped with error message."""
import argparse
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
_write_plugin(plugin_dir / "badplugin.py", """\
MY_GLOBAL = "bad"
""")
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
plugins = Plugins()
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
assert "badplugin" not in plugins.plugins
captured = capsys.readouterr()
assert "Failed to load plugin" in captured.err or "Failed to load plugin" in captured.out
def test_empty_directory(self, tmp_path):
"""Empty directory doesn't cause errors."""
import argparse
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
plugins = Plugins()
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
assert len(plugins.plugins) == 0
+50
View File
@@ -0,0 +1,50 @@
"""Tests for connpy.printer module."""
import sys
from io import StringIO
from connpy import printer
class TestPrinter:
def test_info_output(self, capsys):
printer.info("hello world")
captured = capsys.readouterr()
assert "[i] hello world" in captured.out
def test_success_output(self, capsys):
printer.success("done")
captured = capsys.readouterr()
assert "[✓] done" in captured.out
def test_warning_output(self, capsys):
printer.warning("careful")
captured = capsys.readouterr()
assert "[!] careful" in captured.out
def test_error_output(self, capsys):
printer.error("failed")
captured = capsys.readouterr()
assert "[✗] failed" in captured.err
def test_debug_output(self, capsys):
printer.debug("debug info")
captured = capsys.readouterr()
assert "[d] debug info" in captured.out
def test_start_output(self, capsys):
printer.start("starting")
captured = capsys.readouterr()
assert "[+] starting" in captured.out
def test_custom_output(self, capsys):
printer.custom("TAG", "custom message")
captured = capsys.readouterr()
assert "[TAG] custom message" in captured.out
def test_multiline_indentation(self, capsys):
printer.info("line1\nline2\nline3")
captured = capsys.readouterr()
lines = captured.out.strip().split("\n")
assert lines[0] == "[i] line1"
# Second line should be indented by len("[i] ") = 4 chars
assert lines[1].startswith(" line2")
assert lines[2].startswith(" line3")
+5
View File
@@ -0,0 +1,5 @@
[pytest]
testpaths = connpy/tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*