test: add comprehensive unit tests and fix bare excepts

- Added comprehensive unit test suite covering AI, core, plugins, completion, API, and hooks using pytest.
- Replaced bare 'except:' clauses across the codebase with specific exception handling (e.g., ValueError, KeyError, OSError) to prevent swallowing system exit calls.
- Fixed Python 3.8+ AST compatibility in plugins.py (support for ast.Constant).
- Removed deprecated pkg_resources import from __init__.py.
- Fixed missing 'printer' import in configfile.py that caused NameErrors during save failures.
This commit is contained in:
2026-04-03 17:11:45 -03:00
parent 7de6003435
commit cf95befb43
21 changed files with 2490 additions and 56 deletions
-1
View File
@@ -542,7 +542,6 @@ from .api import *
from .ai import ai
from .plugins import Plugins
from ._version import __version__
from pkg_resources import get_distribution
from . import printer
__all__ = ["node", "nodes", "configfile", "connapp", "ai", "Plugins", "printer"]
+7 -6
View File
@@ -396,7 +396,7 @@ class ai:
if isinstance(commands, str):
try:
commands = json.loads(commands)
except:
except ValueError:
commands = [c.strip() for c in commands.split('\n') if c.strip()]
# Expand multi-line commands within a list (in case the AI packs them)
@@ -795,9 +795,10 @@ class ai:
response = completion(model=model, messages=safe_messages, tools=[], api_key=key)
resp_msg = response.choices[0].message
messages.append(resp_msg.model_dump(exclude_none=True))
except:
pass
except Exception as e:
if status:
status.update(f"[bold red]Error fetching summary: {e}[/bold red]")
printer.warning(f"Failed to fetch final summary from LLM: {e}")
except KeyboardInterrupt:
if status: status.update("[bold red]Interrupted! Closing pending tasks...")
last_msg = messages[-1]
@@ -810,7 +811,7 @@ class ai:
response = completion(model=model, messages=safe_messages, tools=tools, api_key=key)
resp_msg = response.choices[0].message
messages.append(resp_msg.model_dump(exclude_none=True))
except: pass
except Exception: pass
finally:
try:
log_dir = self.config.defaultdir
@@ -820,7 +821,7 @@ class ai:
if os.path.exists(log_path):
try:
with open(log_path, "r") as f: hist = json.load(f)
except: hist = []
except (IOError, json.JSONDecodeError): hist = []
hist.append({"timestamp": datetime.datetime.now().isoformat(), "roles": {"strategic_engine": self.architect_model, "execution_engine": self.engineer_model}, "session": messages})
with open(log_path, "w") as f: json.dump(hist[-10:], f, indent=4)
except Exception as e:
+10 -10
View File
@@ -35,7 +35,7 @@ def list_nodes():
else:
filter = filter.lower()
output = conf._getallnodes(filter)
except:
except Exception:
output = conf._getallnodes()
return jsonify(output)
@@ -52,7 +52,7 @@ def get_nodes():
else:
filter = filter.lower()
output = conf._getallnodesfull(filter)
except:
except Exception:
output = conf._getallnodesfull()
return jsonify(output)
@@ -109,13 +109,13 @@ def run_commands():
mynodes = nodes(mynodes, config=conf)
try:
args["vars"] = data["vars"]
except:
except Exception:
pass
try:
options = data["options"]
thisoptions = {k: v for k, v in options.items() if k in ["prompt", "parallel", "timeout"]}
args.update(thisoptions)
except:
except Exception:
options = None
if action == "run":
output = mynodes.run(**args)
@@ -136,20 +136,20 @@ def stop_api():
pid = int(f.readline().strip())
port = int(f.readline().strip())
PID_FILE=PID_FILE1
except:
except (FileNotFoundError, ValueError, OSError):
try:
with open(PID_FILE2, "r") as f:
pid = int(f.readline().strip())
port = int(f.readline().strip())
PID_FILE=PID_FILE2
except:
except (FileNotFoundError, ValueError, OSError):
printer.warning("Connpy API server is not running.")
return
# Send a SIGTERM signal to the process
try:
os.kill(pid, signal.SIGTERM)
except:
pass
except OSError as e:
printer.warning(f"Process kill failed (maybe already dead): {e}")
# Delete the PID file
os.remove(PID_FILE)
printer.info(f"Server with process ID {pid} stopped.")
@@ -177,11 +177,11 @@ def start_api(port=8048):
try:
with open(PID_FILE1, "w") as f:
f.write(str(pid) + "\n" + str(port))
except:
except OSError:
try:
with open(PID_FILE2, "w") as f:
f.write(str(pid) + "\n" + str(port))
except:
except OSError:
printer.error("Couldn't create PID file.")
exit(1)
printer.start(f"Server is running with process ID {pid} on port {port}")
+2 -2
View File
@@ -95,7 +95,7 @@ def main():
try:
with open(pathfile, "r") as f:
configdir = f.read().strip()
except:
except (FileNotFoundError, IOError):
configdir = defaultdir
defaultfile = configdir + '/config.json'
jsonconf = open(defaultfile)
@@ -131,7 +131,7 @@ def main():
spec.loader.exec_module(module)
plugin_completion = getattr(module, "_connpy_completion")
strings = plugin_completion(wordsnumber, words, info)
except:
except Exception:
exit()
elif wordsnumber >= 3 and words[0] == "ai":
if wordsnumber == 3:
+10 -8
View File
@@ -8,6 +8,7 @@ from Crypto.Cipher import PKCS1_OAEP
from pathlib import Path
from copy import deepcopy
from .hooks import MethodHook, ClassHook
from . import printer
@@ -60,7 +61,7 @@ class configfile:
try:
with open(pathfile, "r") as f:
configdir = f.read().strip()
except:
except (FileNotFoundError, IOError):
with open(pathfile, "w") as f:
f.write(str(defaultdir))
configdir = defaultdir
@@ -120,7 +121,8 @@ class configfile:
with open(conf, "w") as f:
json.dump(newconfig, f, indent = 4)
f.close()
except:
except (IOError, OSError) as e:
printer.error(f"Failed to save config: {e}")
return 1
return 0
@@ -205,12 +207,12 @@ class configfile:
if profile:
try:
newfolder[node_name][key] = self.profiles[profile.group(1)][key]
except:
except KeyError:
newfolder[node_name][key] = ""
elif value == '' and key == "protocol":
try:
newfolder[node_name][key] = self.profiles["default"][key]
except:
except KeyError:
newfolder[node_name][key] = "ssh"
newfolder = {"{}{}".format(k,unique):v for k,v in newfolder.items()}
@@ -231,12 +233,12 @@ class configfile:
if profile:
try:
newnode[key] = self.profiles[profile.group(1)][key]
except:
except KeyError:
newnode[key] = ""
elif value == '' and key == "protocol":
try:
newnode[key] = self.profiles["default"][key]
except:
except KeyError:
newnode[key] = "ssh"
return newnode
@@ -391,12 +393,12 @@ class configfile:
if profile:
try:
nodes[node][key] = self.profiles[profile.group(1)][key]
except:
except KeyError:
nodes[node][key] = ""
elif value == '' and key == "protocol":
try:
nodes[node][key] = self.profiles["default"][key]
except:
except KeyError:
nodes[node][key] = "ssh"
return nodes
+16 -17
View File
@@ -17,7 +17,6 @@ import shutil
class NoAliasDumper(yaml.SafeDumper):
def ignore_aliases(self, data):
return True
import ast
from rich.markdown import Markdown
from rich.console import Console, Group
from rich.panel import Panel
@@ -29,7 +28,7 @@ mdprint = Console().print
console = Console()
try:
from pyfzf.pyfzf import FzfPrompt
except:
except ImportError:
FzfPrompt = None
@@ -64,7 +63,7 @@ class connapp:
self.case = self.config.config["case"]
try:
self.fzf = self.config.config["fzf"]
except:
except KeyError:
self.fzf = False
@@ -178,13 +177,13 @@ class connapp:
try:
core_path = os.path.dirname(os.path.realpath(__file__)) + "/core_plugins"
self.plugins._import_plugins_to_argparse(core_path, subparsers)
except:
pass
except Exception as e:
printer.warning(e)
try:
file_path = self.config.defaultdir + "/plugins"
self.plugins._import_plugins_to_argparse(file_path, subparsers)
except:
pass
except Exception as e:
printer.warning(e)
for preload in self.plugins.preloads.values():
preload.Preload(self)
#Generate helps
@@ -826,7 +825,7 @@ class connapp:
try:
with open(args.data[0]) as file:
imported = yaml.load(file, Loader=yaml.FullLoader)
except:
except Exception:
printer.error("failed reading file {}".format(args.data[0]))
exit(10)
for k,v in imported.items():
@@ -1013,7 +1012,7 @@ class connapp:
try:
with open(args.data[0]) as file:
scripts = yaml.load(file, Loader=yaml.FullLoader)
except:
except Exception:
printer.error("failed reading file {}".format(args.data[0]))
exit(10)
for script in scripts["tasks"]:
@@ -1053,13 +1052,13 @@ class connapp:
options = script["options"]
thisoptions = {k: v for k, v in options.items() if k in ["prompt", "parallel", "timeout"]}
args.update(thisoptions)
except:
except KeyError:
options = None
try:
size = str(os.get_terminal_size())
p = re.search(r'.*columns=([0-9]+)', size)
columns = int(p.group(1))
except:
except (ValueError, OSError):
columns = 80
PANEL_WIDTH = columns
@@ -1182,7 +1181,7 @@ class connapp:
raise inquirer.errors.ValidationError("", reason="Pick a port between 1-65535, @profile o leave empty")
try:
port = int(current)
except:
except ValueError:
port = 0
if current != "" and not 1 <= int(port) <= 65535:
raise inquirer.errors.ValidationError("", reason="Pick a port between 1-65535 or leave empty")
@@ -1194,7 +1193,7 @@ class connapp:
raise inquirer.errors.ValidationError("", reason="Pick a port between 1-6553/app5, @profile or leave empty")
try:
port = int(current)
except:
except ValueError:
port = 0
if current.startswith("@"):
if current[1:] not in self.profiles:
@@ -1220,7 +1219,7 @@ class connapp:
isdict = False
try:
isdict = ast.literal_eval(current)
except:
except Exception:
pass
if not isinstance (isdict, dict):
raise inquirer.errors.ValidationError("", reason="Tags should be a python dictionary.".format(current))
@@ -1232,7 +1231,7 @@ class connapp:
isdict = False
try:
isdict = ast.literal_eval(current)
except:
except Exception:
pass
if not isinstance (isdict, dict):
raise inquirer.errors.ValidationError("", reason="Tags should be a python dictionary.".format(current))
@@ -1316,7 +1315,7 @@ class connapp:
defaults["tags"] = ""
if "jumphost" not in defaults:
defaults["jumphost"] = ""
except:
except KeyError:
defaults = { "host":"", "protocol":"", "port":"", "user":"", "options":"", "logs":"" , "tags":"", "password":"", "jumphost":""}
node = {}
if edit == None:
@@ -1390,7 +1389,7 @@ class connapp:
defaults["tags"] = ""
if "jumphost" not in defaults:
defaults["jumphost"] = ""
except:
except KeyError:
defaults = { "host":"", "protocol":"", "port":"", "user":"", "options":"", "logs":"", "tags": "", "jumphost": ""}
profile = {}
if edit == None:
+5 -5
View File
@@ -84,12 +84,12 @@ class node:
if profile and config != '':
try:
setattr(self,key,config.profiles[profile.group(1)][key])
except:
except KeyError:
setattr(self,key,"")
elif attr[key] == '' and key == "protocol":
try:
setattr(self,key,config.profiles["default"][key])
except:
except (KeyError, AttributeError):
setattr(self,key,"ssh")
else:
setattr(self,key,attr[key])
@@ -108,12 +108,12 @@ class node:
if profile:
try:
self.jumphost[key] = config.profiles[profile.group(1)][key]
except:
except KeyError:
self.jumphost[key] = ""
elif self.jumphost[key] == '' and key == "protocol":
try:
self.jumphost[key] = config.profiles["default"][key]
except:
except KeyError:
self.jumphost[key] = "ssh"
if isinstance(self.jumphost["password"],list):
jumphost_password = []
@@ -158,7 +158,7 @@ class node:
try:
decrypted = decryptor.decrypt(ast.literal_eval(passwd)).decode("utf-8")
dpass.append(decrypted)
except:
except Exception:
raise ValueError("Missing or corrupted key")
return dpass
+3 -3
View File
@@ -176,7 +176,7 @@ class RemoteCapture:
printer.success("Tcpdump finished capturing packets.")
self.listener_active = False
except:
except Exception:
pass
def _sendline_until_connected(self, cmd, retries=5, interval=2):
@@ -307,7 +307,7 @@ class RemoteCapture:
try:
self.fake_connection = True
socket.create_connection(("localhost", self.local_port), timeout=1).close()
except:
except OSError:
pass
self.listener_active = False
return
@@ -324,7 +324,7 @@ class RemoteCapture:
try:
self.listener_conn.shutdown(socket.SHUT_RDWR)
self.listener_conn.close()
except:
except OSError:
pass
if hasattr(self.node, "child"):
self.node.child.close(force=True)
+2 -2
View File
@@ -28,7 +28,7 @@ class sync:
self.connapp = connapp
try:
self.sync = self.connapp.config.config["sync"]
except:
except KeyError:
self.sync = False
def login(self):
@@ -322,7 +322,7 @@ class sync:
def config_listener_pre(self, *args, **kwargs):
try:
self.sync = self.connapp.config.config["sync"]
except:
except KeyError:
self.sync = False
return args, kwargs
+2 -2
View File
@@ -62,8 +62,8 @@ class Plugins:
if not (isinstance(node.test, ast.Compare) and
isinstance(node.test.left, ast.Name) and
node.test.left.id == '__name__' and
isinstance(node.test.comparators[0], ast.Str) and
node.test.comparators[0].s == '__main__'):
((hasattr(ast, 'Str') and isinstance(node.test.comparators[0], getattr(ast, 'Str')) and node.test.comparators[0].s == '__main__') or
(hasattr(ast, 'Constant') and isinstance(node.test.comparators[0], getattr(ast, 'Constant')) and node.test.comparators[0].value == '__main__'))):
return "Only __name__ == __main__ If is allowed"
elif not isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Import, ast.ImportFrom, ast.Pass)):
+1
View File
@@ -0,0 +1 @@
# Tests package
+192
View File
@@ -0,0 +1,192 @@
"""Shared fixtures for connpy tests.
All tests use tmp_path to create isolated config/keys.
No test touches ~/.config/conn/
"""
import pytest
import json
import os
from unittest.mock import patch, MagicMock
from Crypto.PublicKey import RSA
# ---------------------------------------------------------------------------
# Minimal config data
# ---------------------------------------------------------------------------
DEFAULT_CONFIG = {
"config": {"case": False, "idletime": 30, "fzf": False},
"connections": {},
"profiles": {
"default": {
"host": "", "protocol": "ssh", "port": "", "user": "",
"password": "", "options": "", "logs": "", "tags": "", "jumphost": ""
}
}
}
SAMPLE_CONNECTIONS = {
"router1": {
"host": "10.0.0.1", "protocol": "ssh", "port": "22",
"user": "admin", "password": "pass1", "options": "",
"logs": "", "tags": "", "jumphost": "", "type": "connection"
},
"office": {
"type": "folder",
"server1": {
"host": "10.0.1.1", "protocol": "ssh", "port": "",
"user": "root", "password": "pass2", "options": "",
"logs": "", "tags": "", "jumphost": "", "type": "connection"
},
"datacenter": {
"type": "subfolder",
"db1": {
"host": "10.0.2.1", "protocol": "ssh", "port": "",
"user": "dbadmin", "password": "pass3", "options": "",
"logs": "", "tags": "", "jumphost": "", "type": "connection"
}
}
}
}
SAMPLE_PROFILES = {
"default": {
"host": "", "protocol": "ssh", "port": "", "user": "",
"password": "", "options": "", "logs": "", "tags": "", "jumphost": ""
},
"office-user": {
"host": "", "protocol": "ssh", "port": "", "user": "officeadmin",
"password": "officepass", "options": "", "logs": "", "tags": "", "jumphost": ""
}
}
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def tmp_config_dir(tmp_path):
"""Create an isolated config directory with config.json and RSA key."""
config_dir = tmp_path / ".config" / "conn"
config_dir.mkdir(parents=True)
plugins_dir = config_dir / "plugins"
plugins_dir.mkdir()
# Write config.json
config_file = config_dir / "config.json"
config_file.write_text(json.dumps(DEFAULT_CONFIG, indent=4))
os.chmod(str(config_file), 0o600)
# Write .folder (points to itself)
folder_file = config_dir / ".folder"
folder_file.write_text(str(config_dir))
# Generate RSA key
key = RSA.generate(2048)
key_file = config_dir / ".osk"
key_file.write_bytes(key.export_key("PEM"))
os.chmod(str(key_file), 0o600)
return config_dir
@pytest.fixture
def config(tmp_config_dir):
"""Create a configfile instance pointing to tmp directory."""
from connpy.configfile import configfile
conf_path = str(tmp_config_dir / "config.json")
key_path = str(tmp_config_dir / ".osk")
return configfile(conf=conf_path, key=key_path)
@pytest.fixture
def populated_config(tmp_config_dir):
"""Create a configfile with sample nodes/profiles pre-loaded."""
config_file = tmp_config_dir / "config.json"
data = {
"config": {"case": False, "idletime": 30, "fzf": False},
"connections": SAMPLE_CONNECTIONS,
"profiles": SAMPLE_PROFILES
}
config_file.write_text(json.dumps(data, indent=4))
from connpy.configfile import configfile
return configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
@pytest.fixture
def mock_pexpect():
"""Mock pexpect.spawn for connection tests."""
with patch("connpy.core.pexpect") as mock_pexp:
child = MagicMock()
child.before = b""
child.after = b"router#"
child.readline.return_value = b""
child.child_fd = 3
mock_pexp.spawn.return_value = child
mock_pexp.EOF = object()
mock_pexp.TIMEOUT = object()
# Also mock fdpexpect
with patch("connpy.core.fdpexpect", create=True) as mock_fd:
mock_fd.fdspawn.return_value = MagicMock()
yield {
"pexpect": mock_pexp,
"child": child,
"fdpexpect": mock_fd
}
@pytest.fixture
def mock_litellm():
"""Mock litellm.completion for AI tests."""
with patch("connpy.ai.completion") as mock_comp:
# Create a default response
msg = MagicMock()
msg.content = "Test response from AI"
msg.tool_calls = None
msg.role = "assistant"
msg.model_dump.return_value = {
"role": "assistant",
"content": "Test response from AI"
}
choice = MagicMock()
choice.message = msg
response = MagicMock()
response.choices = [choice]
response.usage = MagicMock()
response.usage.prompt_tokens = 100
response.usage.completion_tokens = 50
response.usage.total_tokens = 150
mock_comp.return_value = response
yield {
"completion": mock_comp,
"response": response,
"message": msg,
"choice": choice
}
@pytest.fixture
def ai_config(tmp_config_dir):
"""Create a configfile with AI keys configured for AI tests."""
config_file = tmp_config_dir / "config.json"
data = {
"config": {
"case": False, "idletime": 30, "fzf": False,
"ai": {
"engineer_model": "test/test-model",
"engineer_api_key": "test-engineer-key",
"architect_model": "test/test-architect",
"architect_api_key": "test-architect-key"
}
},
"connections": SAMPLE_CONNECTIONS,
"profiles": SAMPLE_PROFILES
}
config_file.write_text(json.dumps(data, indent=4))
from connpy.configfile import configfile
return configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
+397
View File
@@ -0,0 +1,397 @@
"""Tests for connpy.ai module."""
import json
import os
import pytest
from unittest.mock import patch, MagicMock
# =========================================================================
# AI Init tests
# =========================================================================
class TestAIInit:
def test_init_with_keys(self, ai_config, mock_litellm):
"""Initializes correctly when keys are configured."""
from connpy.ai import ai
myai = ai(ai_config)
assert myai.engineer_model == "test/test-model"
assert myai.architect_model == "test/test-architect"
def test_init_missing_engineer_key(self, config):
"""Raises ValueError if engineer key is missing."""
from connpy.ai import ai
with pytest.raises(ValueError, match="Engineer API key"):
ai(config)
def test_init_missing_architect_key_warns(self, ai_config, capsys, mock_litellm):
"""Warns if architect key is missing but doesn't crash."""
# Remove architect key
ai_config.config["ai"]["architect_api_key"] = None
from connpy.ai import ai
# Should not raise
myai = ai(ai_config)
assert myai.architect_key is None
def test_default_models(self, config):
"""Default models are set correctly when not configured."""
config.config["ai"] = {"engineer_api_key": "test-key", "architect_api_key": "test-key"}
from connpy.ai import ai
myai = ai(config)
assert "gemini" in myai.engineer_model.lower()
assert "claude" in myai.architect_model.lower() or "anthropic" in myai.architect_model.lower()
def test_init_loads_memory(self, ai_config, tmp_path, mock_litellm):
"""Loads long-term memory from file if it exists."""
memory_path = os.path.expanduser("~/.config/conn/ai_memory.md")
from connpy.ai import ai
with patch("os.path.exists", side_effect=lambda p: True if p == memory_path else os.path.exists(p)):
with patch("builtins.open", side_effect=lambda f, *a, **kw: (
__import__("io").StringIO("## Memory\nRouter1 is border router")
if f == memory_path else open(f, *a, **kw)
)):
try:
myai = ai(ai_config)
except Exception:
pass # May fail on other file opens, that's ok
# =========================================================================
# register_ai_tool tests
# =========================================================================
class TestRegisterAITool:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def _make_tool_def(self, name="my_tool"):
return {
"type": "function",
"function": {
"name": name,
"description": "Test tool",
"parameters": {"type": "object", "properties": {}}
}
}
def test_register_tool_engineer(self, myai):
tool_def = self._make_tool_def()
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="engineer")
assert len(myai.external_engineer_tools) == 1
assert len(myai.external_architect_tools) == 0
def test_register_tool_architect(self, myai):
tool_def = self._make_tool_def()
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="architect")
assert len(myai.external_architect_tools) == 1
assert len(myai.external_engineer_tools) == 0
def test_register_tool_both(self, myai):
tool_def = self._make_tool_def()
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="both")
assert len(myai.external_engineer_tools) == 1
assert len(myai.external_architect_tools) == 1
def test_register_tool_handler(self, myai):
tool_def = self._make_tool_def("custom_tool")
handler = lambda self, **kw: "result"
myai.register_ai_tool(tool_def, handler)
assert "custom_tool" in myai.external_tool_handlers
assert myai.external_tool_handlers["custom_tool"] is handler
def test_register_tool_prompt_extension(self, myai):
tool_def = self._make_tool_def()
myai.register_ai_tool(
tool_def, lambda self, **kw: "ok",
engineer_prompt="- Custom capability",
architect_prompt=" * Custom tool"
)
assert any("Custom capability" in ext for ext in myai.engineer_prompt_extensions)
assert any("Custom tool" in ext for ext in myai.architect_prompt_extensions)
def test_register_tool_status_formatter(self, myai):
tool_def = self._make_tool_def("status_tool")
formatter = lambda args: f"[STATUS] {args}"
myai.register_ai_tool(tool_def, lambda self, **kw: "ok", status_formatter=formatter)
assert "status_tool" in myai.tool_status_formatters
# =========================================================================
# Dynamic prompts tests
# =========================================================================
class TestDynamicPrompts:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_engineer_prompt_without_extensions(self, myai):
prompt = myai.engineer_system_prompt
assert "Plugin Capabilities" not in prompt
assert "TECHNICAL EXECUTION ENGINE" in prompt
def test_engineer_prompt_with_extensions(self, myai):
myai.engineer_prompt_extensions.append("- AWS Cloud Auditing")
prompt = myai.engineer_system_prompt
assert "Plugin Capabilities" in prompt
assert "AWS Cloud Auditing" in prompt
def test_architect_prompt_without_extensions(self, myai):
prompt = myai.architect_system_prompt
assert "Plugin Capabilities" not in prompt
assert "STRATEGIC REASONING ENGINE" in prompt
def test_architect_prompt_with_extensions(self, myai):
myai.architect_prompt_extensions.append(" * Custom tool available")
prompt = myai.architect_system_prompt
assert "Plugin Capabilities" in prompt
assert "Custom tool available" in prompt
# =========================================================================
# _sanitize_messages tests
# =========================================================================
class TestSanitizeMessages:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_sanitize_empty(self, myai):
assert myai._sanitize_messages([]) == []
def test_sanitize_normal_messages(self, myai):
messages = [
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"}
]
result = myai._sanitize_messages(messages)
assert len(result) == 3
def test_sanitize_removes_orphan_tool_calls(self, myai):
"""Tool calls at the end without responses are removed."""
messages = [
{"role": "user", "content": "do something"},
{"role": "assistant", "content": None, "tool_calls": [
{"id": "tc1", "function": {"name": "list_nodes", "arguments": "{}"}}
]}
# No tool response follows!
]
result = myai._sanitize_messages(messages)
assert len(result) == 1 # Only user message
assert result[0]["role"] == "user"
def test_sanitize_removes_orphan_tool_responses(self, myai):
"""Tool responses without preceding tool_calls are removed."""
messages = [
{"role": "user", "content": "hello"},
{"role": "tool", "tool_call_id": "tc1", "name": "list_nodes", "content": "[]"}
]
result = myai._sanitize_messages(messages)
assert len(result) == 1
assert result[0]["role"] == "user"
def test_sanitize_preserves_valid_tool_pairs(self, myai):
"""Valid assistant+tool_calls followed by tool responses are preserved."""
messages = [
{"role": "user", "content": "list nodes"},
{"role": "assistant", "content": None, "tool_calls": [
{"id": "tc1", "function": {"name": "list_nodes", "arguments": "{}"}}
]},
{"role": "tool", "tool_call_id": "tc1", "name": "list_nodes", "content": "[\"r1\"]"},
{"role": "assistant", "content": "Found r1"}
]
result = myai._sanitize_messages(messages)
assert len(result) == 4
# =========================================================================
# _truncate tests
# =========================================================================
class TestTruncate:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_truncate_short_text(self, myai):
text = "short text"
assert myai._truncate(text) == text
def test_truncate_long_text(self, myai):
text = "x" * 100000
result = myai._truncate(text)
assert len(result) < 100000
assert "[... OUTPUT TRUNCATED ...]" in result
def test_truncate_custom_limit(self, myai):
text = "x" * 1000
result = myai._truncate(text, limit=500)
assert len(result) < 1000
assert "[... OUTPUT TRUNCATED ...]" in result
def test_truncate_preserves_head_and_tail(self, myai):
text = "HEAD" + "x" * 100000 + "TAIL"
result = myai._truncate(text)
assert result.startswith("HEAD")
assert result.endswith("TAIL")
# =========================================================================
# Tool methods tests
# =========================================================================
class TestToolMethods:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_list_nodes_tool_found(self, myai):
result = myai.list_nodes_tool("router.*")
parsed = json.loads(result)
assert "router1" in str(parsed)
def test_list_nodes_tool_not_found(self, myai):
result = myai.list_nodes_tool("nonexistent_pattern_xyz")
assert "No nodes found" in result
def test_get_node_info_masks_password(self, myai):
result = myai.get_node_info_tool("router1")
parsed = json.loads(result)
assert parsed["password"] == "***"
def test_is_safe_command_show(self, myai):
assert myai._is_safe_command("show running-config") == True
assert myai._is_safe_command("show ip int brief") == True
def test_is_safe_command_config(self, myai):
assert myai._is_safe_command("config t") == False
assert myai._is_safe_command("write memory") == False
def test_is_safe_command_ls(self, myai):
assert myai._is_safe_command("ls -la") == True
def test_is_safe_command_ping(self, myai):
assert myai._is_safe_command("ping 10.0.0.1") == True
# =========================================================================
# manage_memory_tool tests
# =========================================================================
class TestManageMemory:
@pytest.fixture
def myai(self, ai_config, mock_litellm, tmp_path):
from connpy.ai import ai
myai = ai(ai_config)
myai.memory_path = str(tmp_path / "ai_memory.md")
return myai
def test_manage_memory_append(self, myai):
result = myai.manage_memory_tool("Router1 is border router", action="append")
assert "successfully" in result.lower()
assert os.path.exists(myai.memory_path)
content = open(myai.memory_path).read()
assert "Router1 is border router" in content
def test_manage_memory_replace(self, myai):
myai.manage_memory_tool("old content", action="append")
myai.manage_memory_tool("new content only", action="replace")
content = open(myai.memory_path).read()
assert "new content only" in content
assert "old content" not in content
def test_manage_memory_empty_content(self, myai):
result = myai.manage_memory_tool("", action="append")
assert "error" in result.lower() or "Error" in result
# =========================================================================
# ask() with mock LLM tests
# =========================================================================
class TestAsk:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_ask_basic_response(self, myai, mock_litellm):
result = myai.ask("hello", stream=False)
assert "response" in result
assert "chat_history" in result
assert "usage" in result
assert result["response"] == "Test response from AI"
def test_ask_sticky_brain_engineer(self, myai, mock_litellm):
result = myai.ask("show me the routers", stream=False)
assert result["responder"] == "engineer"
def test_ask_explicit_architect(self, myai, mock_litellm):
result = myai.ask("architect: review the network design", stream=False)
assert result["responder"] == "architect"
def test_ask_returns_usage(self, myai, mock_litellm):
result = myai.ask("test", stream=False)
assert result["usage"]["total"] > 0
def test_ask_with_chat_history(self, myai, mock_litellm):
history = [
{"role": "user", "content": "previous question"},
{"role": "assistant", "content": "previous answer"}
]
result = myai.ask("follow up", chat_history=history, stream=False)
assert result["response"] is not None
# =========================================================================
# _get_engineer_tools / _get_architect_tools tests
# =========================================================================
class TestToolDefinitions:
@pytest.fixture
def myai(self, ai_config, mock_litellm):
from connpy.ai import ai
return ai(ai_config)
def test_engineer_tools_include_core(self, myai):
tools = myai._get_engineer_tools()
names = [t["function"]["name"] for t in tools]
assert "list_nodes" in names
assert "run_commands" in names
assert "get_node_info" in names
assert "consult_architect" in names
assert "escalate_to_architect" in names
def test_engineer_tools_include_external(self, myai):
myai.external_engineer_tools.append({
"type": "function",
"function": {"name": "custom_tool", "description": "test", "parameters": {}}
})
tools = myai._get_engineer_tools()
names = [t["function"]["name"] for t in tools]
assert "custom_tool" in names
def test_architect_tools_include_core(self, myai):
tools = myai._get_architect_tools()
names = [t["function"]["name"] for t in tools]
assert "delegate_to_engineer" in names
assert "return_to_engineer" in names
assert "manage_memory_tool" in names
def test_architect_tools_include_external(self, myai):
myai.external_architect_tools.append({
"type": "function",
"function": {"name": "arch_tool", "description": "test", "parameters": {}}
})
tools = myai._get_architect_tools()
names = [t["function"]["name"] for t in tools]
assert "arch_tool" in names
+268
View File
@@ -0,0 +1,268 @@
"""Tests for connpy.api module — Flask routes."""
import json
import pytest
from unittest.mock import patch, MagicMock
@pytest.fixture
def api_client(populated_config):
"""Create a Flask test client with a populated config."""
from connpy.api import app
app.custom_config = populated_config
app.config["TESTING"] = True
with app.test_client() as client:
yield client
# =========================================================================
# Root endpoint
# =========================================================================
class TestRootEndpoint:
def test_root_returns_welcome(self, api_client):
response = api_client.get("/")
data = response.get_json()
assert response.status_code == 200
assert "Welcome" in data["message"]
assert "version" in data
# =========================================================================
# /list_nodes endpoint
# =========================================================================
class TestListNodes:
def test_list_nodes_no_filter(self, api_client):
response = api_client.post("/list_nodes", json={})
data = response.get_json()
assert response.status_code == 200
assert isinstance(data, list)
assert "router1" in data
def test_list_nodes_with_filter(self, api_client):
response = api_client.post("/list_nodes", json={"filter": "router.*"})
data = response.get_json()
assert "router1" in data
assert all("router" in n or "Router" in n for n in data)
def test_list_nodes_case_insensitive(self, api_client):
"""Filter is lowercased when case=false."""
response = api_client.post("/list_nodes", json={"filter": "ROUTER.*"})
data = response.get_json()
# Should still match since the filter gets lowercased
assert isinstance(data, list)
def test_list_nodes_no_body(self, api_client):
"""No body returns all nodes."""
response = api_client.post("/list_nodes",
data="",
content_type="application/json")
data = response.get_json()
assert isinstance(data, list)
# =========================================================================
# /get_nodes endpoint
# =========================================================================
class TestGetNodes:
def test_get_nodes_no_filter(self, api_client):
response = api_client.post("/get_nodes", json={})
data = response.get_json()
assert response.status_code == 200
assert isinstance(data, dict)
assert "router1" in data
def test_get_nodes_with_filter(self, api_client):
response = api_client.post("/get_nodes", json={"filter": "router.*"})
data = response.get_json()
assert "router1" in data
assert "host" in data["router1"]
def test_get_nodes_has_attributes(self, api_client):
response = api_client.post("/get_nodes", json={"filter": "router1"})
data = response.get_json()
if "router1" in data:
assert "host" in data["router1"]
assert "protocol" in data["router1"]
# =========================================================================
# /run_commands endpoint
# =========================================================================
class TestRunCommands:
def test_missing_action(self, api_client):
response = api_client.post("/run_commands", json={
"nodes": "router1",
"commands": ["show version"]
})
data = response.get_json()
assert "DataError" in data
assert "action" in data["DataError"]
def test_missing_nodes(self, api_client):
response = api_client.post("/run_commands", json={
"action": "run",
"commands": ["show version"]
})
data = response.get_json()
assert "DataError" in data
assert "nodes" in data["DataError"]
def test_missing_commands(self, api_client):
response = api_client.post("/run_commands", json={
"action": "run",
"nodes": "router1"
})
data = response.get_json()
assert "DataError" in data
assert "commands" in data["DataError"]
def test_wrong_action(self, api_client):
response = api_client.post("/run_commands", json={
"action": "invalid",
"nodes": "router1",
"commands": ["show version"]
})
data = response.get_json()
assert "DataError" in data
assert "Wrong action" in data["DataError"]
@patch("connpy.api.nodes")
def test_run_action(self, mock_nodes_cls, api_client):
"""action=run executes and returns output."""
mock_instance = MagicMock()
mock_instance.run.return_value = {"router1": "Router v1.0"}
mock_nodes_cls.return_value = mock_instance
response = api_client.post("/run_commands", json={
"action": "run",
"nodes": "router1",
"commands": ["show version"]
})
data = response.get_json()
assert "router1" in data
@patch("connpy.api.nodes")
def test_test_action(self, mock_nodes_cls, api_client):
"""action=test returns result + output."""
mock_instance = MagicMock()
mock_instance.test.return_value = {"router1": {"expected": True}}
mock_instance.output = {"router1": "output text"}
mock_nodes_cls.return_value = mock_instance
response = api_client.post("/run_commands", json={
"action": "test",
"nodes": "router1",
"commands": ["show version"],
"expected": "Router"
})
data = response.get_json()
assert "result" in data
assert "output" in data
@patch("connpy.api.nodes")
def test_run_with_options(self, mock_nodes_cls, api_client):
"""Options get passed through."""
mock_instance = MagicMock()
mock_instance.run.return_value = {"router1": "ok"}
mock_nodes_cls.return_value = mock_instance
response = api_client.post("/run_commands", json={
"action": "run",
"nodes": "router1",
"commands": ["show version"],
"options": {"timeout": 30, "parallel": 5}
})
assert response.status_code == 200
@patch("connpy.api.nodes")
def test_run_folder_nodes(self, mock_nodes_cls, api_client):
"""Nodes with @ prefix are resolved as folders."""
mock_instance = MagicMock()
mock_instance.run.return_value = {"server1@office": "ok"}
mock_nodes_cls.return_value = mock_instance
response = api_client.post("/run_commands", json={
"action": "run",
"nodes": "@office",
"commands": ["ls -la"]
})
assert response.status_code == 200
@patch("connpy.api.nodes")
def test_run_list_nodes(self, mock_nodes_cls, api_client):
"""List of nodes is resolved correctly."""
mock_instance = MagicMock()
mock_instance.run.return_value = {"router1": "ok", "server1@office": "ok"}
mock_nodes_cls.return_value = mock_instance
response = api_client.post("/run_commands", json={
"action": "run",
"nodes": ["router1", "server1@office"],
"commands": ["show version"]
})
assert response.status_code == 200
# =========================================================================
# /ask_ai endpoint
# =========================================================================
class TestAskAI:
@patch("connpy.api.myai")
def test_ask_ai(self, mock_ai_cls, api_client):
mock_instance = MagicMock()
mock_instance.ask.return_value = {"response": "AI says hello"}
mock_ai_cls.return_value = mock_instance
response = api_client.post("/ask_ai", json={
"input": "list my routers"
})
data = response.get_json()
assert data is not None
@patch("connpy.api.myai")
def test_ask_ai_with_dryrun(self, mock_ai_cls, api_client):
mock_instance = MagicMock()
mock_instance.ask.return_value = {"response": "dry run"}
mock_ai_cls.return_value = mock_instance
response = api_client.post("/ask_ai", json={
"input": "test",
"dryrun": True
})
assert response.status_code == 200
@patch("connpy.api.myai")
def test_ask_ai_with_history(self, mock_ai_cls, api_client):
mock_instance = MagicMock()
mock_instance.ask.return_value = {"response": "with history"}
mock_ai_cls.return_value = mock_instance
response = api_client.post("/ask_ai", json={
"input": "follow up",
"chat_history": [
{"role": "user", "content": "previous"},
{"role": "assistant", "content": "answer"}
]
})
assert response.status_code == 200
# =========================================================================
# /confirm endpoint
# =========================================================================
class TestConfirm:
@patch("connpy.api.myai")
def test_confirm(self, mock_ai_cls, api_client):
mock_instance = MagicMock()
mock_instance.confirm.return_value = True
mock_ai_cls.return_value = mock_instance
response = api_client.post("/confirm", json={
"input": "yes"
})
assert response.status_code == 200
+182
View File
@@ -0,0 +1,182 @@
"""Tests for connpy.completion module."""
import os
import json
import pytest
from connpy.completion import _getallnodes, _getallfolders, _getcwd, _get_plugins
# =========================================================================
# _getallnodes tests
# =========================================================================
class TestGetAllNodes:
def test_flat_nodes(self):
"""Nodes without folders."""
config = {
"connections": {
"router1": {"type": "connection"},
"router2": {"type": "connection"}
}
}
nodes = _getallnodes(config)
assert "router1" in nodes
assert "router2" in nodes
def test_nested_nodes(self):
"""Nodes in folders and subfolders have correct format."""
config = {
"connections": {
"router1": {"type": "connection"},
"office": {
"type": "folder",
"server1": {"type": "connection"},
"datacenter": {
"type": "subfolder",
"db1": {"type": "connection"}
}
}
}
}
nodes = _getallnodes(config)
assert "router1" in nodes
assert "server1@office" in nodes
assert "db1@datacenter@office" in nodes
def test_empty_connections(self):
config = {"connections": {}}
nodes = _getallnodes(config)
assert nodes == []
# =========================================================================
# _getallfolders tests
# =========================================================================
class TestGetAllFolders:
def test_basic_folders(self):
config = {
"connections": {
"office": {"type": "folder"},
"home": {"type": "folder"}
}
}
folders = _getallfolders(config)
assert "@office" in folders
assert "@home" in folders
def test_with_subfolders(self):
config = {
"connections": {
"office": {
"type": "folder",
"datacenter": {"type": "subfolder"},
"server1": {"type": "connection"}
}
}
}
folders = _getallfolders(config)
assert "@office" in folders
assert "@datacenter@office" in folders
def test_empty(self):
config = {"connections": {}}
folders = _getallfolders(config)
assert folders == []
# =========================================================================
# _getcwd tests
# =========================================================================
class TestGetCwd:
def test_current_dir(self, tmp_path, monkeypatch):
"""Lists files in current directory."""
monkeypatch.chdir(tmp_path)
(tmp_path / "file1.txt").touch()
(tmp_path / "file2.py").touch()
subdir = tmp_path / "subdir"
subdir.mkdir()
result = _getcwd(["run", "run"], "run")
# Should list files
assert any("file1.txt" in r for r in result)
assert any("subdir/" in r for r in result)
def test_specific_path(self, tmp_path, monkeypatch):
"""Lists files matching a partial path."""
monkeypatch.chdir(tmp_path)
(tmp_path / "script.yaml").touch()
(tmp_path / "script2.yaml").touch()
result = _getcwd(["run", "script"], "run")
assert any("script" in r for r in result)
def test_folder_only(self, tmp_path, monkeypatch):
"""folderonly=True returns only directories."""
monkeypatch.chdir(tmp_path)
(tmp_path / "file.txt").touch()
subdir = tmp_path / "mydir"
subdir.mkdir()
result = _getcwd(["export", "export"], "export", folderonly=True)
files_in_result = [r for r in result if "file.txt" in r]
assert len(files_in_result) == 0
dirs_in_result = [r for r in result if "mydir" in r]
assert len(dirs_in_result) > 0
# =========================================================================
# _get_plugins tests
# =========================================================================
class TestGetPlugins:
def test_get_plugins_disable(self, tmp_path):
"""--disable returns enabled plugins."""
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
(plugin_dir / "active.py").touch()
(plugin_dir / "disabled.py.bkp").touch()
result = _get_plugins("--disable", str(tmp_path))
assert "active" in result
assert "disabled" not in result
def test_get_plugins_enable(self, tmp_path):
"""--enable returns disabled plugins."""
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
(plugin_dir / "active.py").touch()
(plugin_dir / "disabled.py.bkp").touch()
result = _get_plugins("--enable", str(tmp_path))
assert "disabled" in result
assert "active" not in result
def test_get_plugins_del(self, tmp_path):
"""--del returns all plugins."""
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
(plugin_dir / "active.py").touch()
(plugin_dir / "disabled.py.bkp").touch()
result = _get_plugins("--del", str(tmp_path))
assert "active" in result
assert "disabled" in result
def test_get_plugins_all(self, tmp_path):
"""'all' returns dict with paths."""
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
(plugin_dir / "myplugin.py").touch()
result = _get_plugins("all", str(tmp_path))
assert isinstance(result, dict)
assert "myplugin" in result
def test_get_plugins_empty_dir(self, tmp_path):
"""Empty plugins directory returns empty list."""
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
result = _get_plugins("--disable", str(tmp_path))
assert result == []
+376
View File
@@ -0,0 +1,376 @@
"""Tests for connpy.configfile module."""
import json
import os
import re
import pytest
from copy import deepcopy
class TestConfigfileInit:
def test_creates_default_config(self, tmp_config_dir):
"""Creates config.json with defaults when it doesn't exist."""
config_file = tmp_config_dir / "config.json"
config_file.unlink() # Remove existing
key_file = tmp_config_dir / ".osk"
from connpy.configfile import configfile
conf = configfile(conf=str(config_file), key=str(key_file))
assert config_file.exists()
assert conf.config["case"] == False
assert conf.config["idletime"] == 30
assert "default" in conf.profiles
def test_creates_rsa_key(self, tmp_config_dir):
"""Generates RSA key when it doesn't exist."""
key_file = tmp_config_dir / ".osk"
key_file.unlink() # Remove existing
from connpy.configfile import configfile
conf = configfile(conf=str(tmp_config_dir / "config.json"), key=str(key_file))
assert key_file.exists()
assert conf.privatekey is not None
assert conf.publickey is not None
def test_loads_existing_config(self, config):
"""Loads correctly from existing config."""
assert config.config is not None
assert config.connections is not None
assert config.profiles is not None
def test_config_file_permissions(self, tmp_config_dir):
"""Config is created with 0o600 permissions."""
config_file = tmp_config_dir / "config.json"
config_file.unlink()
from connpy.configfile import configfile
configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
stat = os.stat(str(config_file))
assert oct(stat.st_mode & 0o777) == oct(0o600)
def test_custom_paths(self, tmp_path):
"""Accepts custom paths for conf and key."""
config_dir = tmp_path / "custom"
config_dir.mkdir()
(config_dir / "plugins").mkdir()
# Write .folder for the config dir
dot_folder = tmp_path / ".config" / "conn"
dot_folder.mkdir(parents=True, exist_ok=True)
(dot_folder / ".folder").write_text(str(config_dir))
(dot_folder / "plugins").mkdir(exist_ok=True)
conf_path = str(config_dir / "my_config.json")
key_path = str(config_dir / "my_key")
from connpy.configfile import configfile
conf = configfile(conf=conf_path, key=key_path)
assert conf.file == conf_path
assert conf.key == key_path
class TestEncryption:
def test_encrypt_password(self, config):
"""Encrypts and produces b'...' format."""
encrypted = config.encrypt("mysecret")
assert encrypted.startswith("b'") or encrypted.startswith('b"')
def test_encrypt_decrypt_roundtrip(self, config):
"""Encrypt then decrypt returns original."""
from Crypto.PublicKey import RSA
from Crypto.Cipher import PKCS1_OAEP
import ast
original = "super_secret_password"
encrypted = config.encrypt(original)
# Decrypt
with open(config.key) as f:
key = RSA.import_key(f.read())
decryptor = PKCS1_OAEP.new(key)
decrypted = decryptor.decrypt(ast.literal_eval(encrypted)).decode("utf-8")
assert decrypted == original
class TestExplodeUnique:
def test_simple_node(self, config):
result = config._explode_unique("router1")
assert result == {"id": "router1"}
def test_node_with_folder(self, config):
result = config._explode_unique("r1@office")
assert result == {"id": "r1", "folder": "office"}
def test_node_with_subfolder(self, config):
result = config._explode_unique("r1@dc@office")
assert result == {"id": "r1", "folder": "office", "subfolder": "dc"}
def test_folder_only(self, config):
result = config._explode_unique("@office")
assert result == {"folder": "office"}
def test_subfolder_only(self, config):
result = config._explode_unique("@dc@office")
assert result == {"folder": "office", "subfolder": "dc"}
def test_too_deep(self, config):
result = config._explode_unique("a@b@c@d")
assert result == False
def test_empty_folder(self, config):
result = config._explode_unique("a@")
assert result == False
def test_empty_subfolder(self, config):
result = config._explode_unique("a@@office")
assert result == False
class TestCRUDNodes:
def test_add_node_root(self, config):
config._connections_add(
id="router1", host="10.0.0.1", protocol="ssh",
port="22", user="admin", password="pass", options="",
logs="", tags="", jumphost=""
)
assert "router1" in config.connections
assert config.connections["router1"]["host"] == "10.0.0.1"
def test_add_node_folder(self, config):
config._folder_add(folder="office")
config._connections_add(
id="server1", folder="office", host="10.0.1.1",
protocol="ssh", port="", user="root", password="pass",
options="", logs="", tags="", jumphost=""
)
assert "server1" in config.connections["office"]
def test_add_node_subfolder(self, config):
config._folder_add(folder="office")
config._folder_add(folder="office", subfolder="dc")
config._connections_add(
id="db1", folder="office", subfolder="dc", host="10.0.2.1",
protocol="ssh", port="", user="dbadmin", password="pass",
options="", logs="", tags="", jumphost=""
)
assert "db1" in config.connections["office"]["dc"]
def test_del_node_root(self, config):
config._connections_add(
id="router1", host="10.0.0.1", protocol="ssh",
port="", user="", password="", options="",
logs="", tags="", jumphost=""
)
config._connections_del(id="router1")
assert "router1" not in config.connections
def test_del_node_folder(self, config):
config._folder_add(folder="office")
config._connections_add(
id="server1", folder="office", host="10.0.1.1",
protocol="ssh", port="", user="", password="",
options="", logs="", tags="", jumphost=""
)
config._connections_del(id="server1", folder="office")
assert "server1" not in config.connections["office"]
def test_add_folder(self, config):
config._folder_add(folder="office")
assert "office" in config.connections
assert config.connections["office"]["type"] == "folder"
def test_add_subfolder(self, config):
config._folder_add(folder="office")
config._folder_add(folder="office", subfolder="dc")
assert "dc" in config.connections["office"]
assert config.connections["office"]["dc"]["type"] == "subfolder"
def test_del_folder(self, config):
config._folder_add(folder="office")
config._folder_del(folder="office")
assert "office" not in config.connections
def test_del_subfolder(self, config):
config._folder_add(folder="office")
config._folder_add(folder="office", subfolder="dc")
config._folder_del(folder="office", subfolder="dc")
assert "dc" not in config.connections["office"]
class TestCRUDProfiles:
def test_add_profile(self, config):
config._profiles_add(
id="myprofile", host="", protocol="telnet",
port="23", user="user1", password="pass1",
options="", logs="", tags="", jumphost=""
)
assert "myprofile" in config.profiles
assert config.profiles["myprofile"]["protocol"] == "telnet"
def test_del_profile(self, config):
config._profiles_add(
id="temp", host="", protocol="ssh", port="",
user="", password="", options="", logs="", tags="", jumphost=""
)
config._profiles_del(id="temp")
assert "temp" not in config.profiles
def test_default_profile_exists(self, config):
assert "default" in config.profiles
class TestGetItem:
def test_getitem_node(self, populated_config):
node = populated_config.getitem("router1")
assert node["host"] == "10.0.0.1"
assert "type" not in node # type is stripped
def test_getitem_folder(self, populated_config):
nodes = populated_config.getitem("@office")
# Should contain server1@office but NOT datacenter (subfolder)
assert "server1@office" in nodes
assert all("type" not in v for v in nodes.values())
def test_getitem_subfolder(self, populated_config):
nodes = populated_config.getitem("@datacenter@office")
assert "db1@datacenter@office" in nodes
def test_getitem_node_in_folder(self, populated_config):
node = populated_config.getitem("server1@office")
assert node["host"] == "10.0.1.1"
def test_getitem_node_in_subfolder(self, populated_config):
node = populated_config.getitem("db1@datacenter@office")
assert node["host"] == "10.0.2.1"
def test_getitem_with_profile_extraction(self, tmp_config_dir):
"""extract=True resolves @profile references."""
config_file = tmp_config_dir / "config.json"
data = {
"config": {"case": False, "idletime": 30, "fzf": False},
"connections": {
"router1": {
"host": "10.0.0.1", "protocol": "ssh", "port": "",
"user": "@office-user", "password": "@office-user",
"options": "", "logs": "", "tags": "", "jumphost": "",
"type": "connection"
}
},
"profiles": {
"default": {"host": "", "protocol": "ssh", "port": "",
"user": "", "password": "", "options": "",
"logs": "", "tags": "", "jumphost": ""},
"office-user": {"host": "", "protocol": "ssh", "port": "",
"user": "officeadmin", "password": "officepass",
"options": "", "logs": "", "tags": "", "jumphost": ""}
}
}
config_file.write_text(json.dumps(data, indent=4))
from connpy.configfile import configfile
conf = configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
node = conf.getitem("router1", extract=True)
assert node["user"] == "officeadmin"
assert node["password"] == "officepass"
def test_getitems_multiple(self, populated_config):
nodes = populated_config.getitems(["router1", "server1@office"])
assert "router1" in nodes
assert "server1@office" in nodes
def test_getitems_folder(self, populated_config):
nodes = populated_config.getitems(["@office"])
assert "server1@office" in nodes
class TestGetAll:
def test_getallnodes_no_filter(self, populated_config):
nodes = populated_config._getallnodes()
assert "router1" in nodes
assert "server1@office" in nodes
assert "db1@datacenter@office" in nodes
def test_getallnodes_string_filter(self, populated_config):
nodes = populated_config._getallnodes("router.*")
assert "router1" in nodes
assert "server1@office" not in nodes
def test_getallnodes_list_filter(self, populated_config):
nodes = populated_config._getallnodes(["router.*", "db.*"])
assert "router1" in nodes
assert "db1@datacenter@office" in nodes
assert "server1@office" not in nodes
def test_getallnodes_filter_invalid_type(self, populated_config):
with pytest.raises(ValueError):
populated_config._getallnodes(123)
def test_getallfolders(self, populated_config):
folders = populated_config._getallfolders()
assert "@office" in folders
assert "@datacenter@office" in folders
def test_getallnodesfull(self, populated_config):
nodes = populated_config._getallnodesfull()
assert "router1" in nodes
assert nodes["router1"]["host"] == "10.0.0.1"
def test_getallnodesfull_with_filter(self, populated_config):
nodes = populated_config._getallnodesfull("router.*")
assert "router1" in nodes
assert "server1@office" not in nodes
def test_profileused(self, tmp_config_dir):
"""Detects nodes using a specific profile."""
config_file = tmp_config_dir / "config.json"
data = {
"config": {"case": False, "idletime": 30, "fzf": False},
"connections": {
"router1": {
"host": "10.0.0.1", "protocol": "ssh", "port": "",
"user": "@myprofile", "password": "pass",
"options": "", "logs": "", "tags": "", "jumphost": "",
"type": "connection"
},
"router2": {
"host": "10.0.0.2", "protocol": "ssh", "port": "",
"user": "admin", "password": "pass",
"options": "", "logs": "", "tags": "", "jumphost": "",
"type": "connection"
}
},
"profiles": {
"default": {"host": "", "protocol": "ssh", "port": "",
"user": "", "password": "", "options": "",
"logs": "", "tags": "", "jumphost": ""},
"myprofile": {"host": "", "protocol": "ssh", "port": "",
"user": "profuser", "password": "profpass",
"options": "", "logs": "", "tags": "", "jumphost": ""}
}
}
config_file.write_text(json.dumps(data, indent=4))
from connpy.configfile import configfile
conf = configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk"))
used = conf._profileused("myprofile")
assert "router1" in used
assert "router2" not in used
def test_saveconfig(self, config):
"""Save and reload correctly."""
config._connections_add(
id="test_node", host="1.2.3.4", protocol="ssh",
port="", user="", password="", options="",
logs="", tags="", jumphost=""
)
result = config._saveconfig(config.file)
assert result == 0
# Reload and verify
from connpy.configfile import configfile
reloaded = configfile(conf=config.file, key=config.key)
assert "test_node" in reloaded.connections
+419
View File
@@ -0,0 +1,419 @@
"""Tests for connpy.core module — node and nodes classes."""
import json
import os
import io
import re
import pytest
from unittest.mock import patch, MagicMock, PropertyMock
from copy import deepcopy
# =========================================================================
# node.__init__ tests
# =========================================================================
class TestNodeInit:
def test_basic_init(self):
"""Creates node with basic attributes."""
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="pass1", protocol="ssh")
assert n.unique == "router1"
assert n.host == "10.0.0.1"
assert n.user == "admin"
assert n.protocol == "ssh"
assert n.password == ["pass1"]
def test_default_protocol(self):
"""Default protocol is ssh."""
from connpy.core import node
n = node("router1", "10.0.0.1")
assert n.protocol == "ssh"
def test_password_as_list_of_profiles(self, populated_config):
"""Password list with @profile references resolves correctly."""
from connpy.core import node
n = node("router1", "10.0.0.1", password=["@office-user"],
config=populated_config)
assert n.password == ["officepass"]
def test_password_plain_string(self):
"""Plain string password is wrapped in a list."""
from connpy.core import node
n = node("router1", "10.0.0.1", password="mypass")
assert n.password == ["mypass"]
def test_node_with_profile(self, populated_config):
"""Resolves @profile references for user."""
from connpy.core import node
n = node("test1", "10.0.0.1", user="@office-user", password="plain",
config=populated_config)
assert n.user == "officeadmin"
def test_node_tags(self):
"""Tags are stored correctly."""
from connpy.core import node
tags = {"os": "cisco_ios", "prompt": r"Router#"}
n = node("router1", "10.0.0.1", tags=tags)
assert n.tags["os"] == "cisco_ios"
# =========================================================================
# Command generation tests
# =========================================================================
class TestCommandGeneration:
def _make_node(self, **kwargs):
from connpy.core import node
defaults = {
"unique": "test", "host": "10.0.0.1", "protocol": "ssh",
"user": "admin", "password": "", "port": "", "options": "",
"jumphost": "", "tags": "", "logs": ""
}
defaults.update(kwargs)
return node(defaults.pop("unique"), defaults.pop("host"), **defaults)
def test_ssh_cmd_basic(self):
n = self._make_node()
cmd = n._get_cmd()
assert "ssh" in cmd
assert "admin@10.0.0.1" in cmd
def test_ssh_cmd_port(self):
n = self._make_node(port="2222")
cmd = n._get_cmd()
assert "-p 2222" in cmd
def test_ssh_cmd_options(self):
n = self._make_node(options="-o StrictHostKeyChecking=no")
cmd = n._get_cmd()
assert "-o StrictHostKeyChecking=no" in cmd
def test_sftp_cmd_port(self):
n = self._make_node(protocol="sftp", port="2222")
cmd = n._get_cmd()
assert "-P 2222" in cmd # SFTP uses uppercase P
def test_telnet_cmd(self):
n = self._make_node(protocol="telnet", port="23")
cmd = n._get_cmd()
assert "telnet 10.0.0.1" in cmd
assert "23" in cmd
def test_kubectl_cmd(self):
n = self._make_node(protocol="kubectl", host="my-pod", tags={"kube_command": "/bin/sh"})
cmd = n._get_cmd()
assert "kubectl exec" in cmd
assert "my-pod" in cmd
assert "/bin/sh" in cmd
def test_kubectl_cmd_default_command(self):
n = self._make_node(protocol="kubectl", host="my-pod")
cmd = n._get_cmd()
assert "/bin/bash" in cmd
def test_docker_cmd(self):
n = self._make_node(protocol="docker", host="my-container",
tags={"docker_command": "/bin/sh"})
cmd = n._get_cmd()
assert "docker" in cmd
assert "my-container" in cmd
assert "/bin/sh" in cmd
def test_invalid_protocol_raises(self):
n = self._make_node(protocol="invalid_proto")
with pytest.raises(ValueError, match="Invalid protocol"):
n._get_cmd()
def test_ssh_cmd_no_user(self):
n = self._make_node(user="")
cmd = n._get_cmd()
assert "10.0.0.1" in cmd
assert "@" not in cmd # No user@ prefix
# =========================================================================
# Password decryption tests
# =========================================================================
class TestPasswordDecryption:
def test_passtx_plaintext(self, config):
"""Plaintext passwords pass through unchanged."""
from connpy.core import node
n = node("test", "10.0.0.1", password="plainpass", config=config)
result = n._passtx(["plainpass"])
assert result == ["plainpass"]
def test_passtx_encrypted(self, config):
"""Encrypted passwords get decrypted."""
from connpy.core import node
encrypted = config.encrypt("mysecret")
n = node("test", "10.0.0.1", password=encrypted, config=config)
result = n._passtx([encrypted])
assert result == ["mysecret"]
def test_passtx_missing_key_raises(self):
"""Missing key file raises ValueError."""
from connpy.core import node
n = node("test", "10.0.0.1", password="pass")
# A password formatted as encrypted but no valid key
with pytest.raises((ValueError, Exception)):
n._passtx(["""b'corrupted_encrypted_data'"""], keyfile="/nonexistent")
# =========================================================================
# Log handling tests
# =========================================================================
class TestLogHandling:
def test_logfile_variable_substitution(self):
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", protocol="ssh", port="22",
logs="/logs/${unique}_${host}_${user}")
result = n._logfile()
assert result == "/logs/router1_10.0.0.1_admin"
def test_logfile_date_substitution(self):
from connpy.core import node
import datetime
n = node("router1", "10.0.0.1", logs="/logs/${date '%Y'}")
result = n._logfile()
assert datetime.datetime.now().strftime("%Y") in result
def test_logclean_removes_ansi(self):
from connpy.core import node
n = node("test", "10.0.0.1")
dirty = "\x1B[32mgreen text\x1B[0m"
clean = n._logclean(dirty, var=True)
assert "\x1B" not in clean
assert "green text" in clean
def test_logclean_removes_backspaces(self):
from connpy.core import node
n = node("test", "10.0.0.1")
dirty = "type\bo"
clean = n._logclean(dirty, var=True)
assert "\b" not in clean
# =========================================================================
# run() and test() with mock pexpect
# =========================================================================
class TestNodeRun:
def _make_connected_node(self, mock_pexpect_obj, **kwargs):
"""Create a node and mock its _connect to succeed."""
from connpy.core import node
defaults = {
"unique": "router1", "host": "10.0.0.1",
"protocol": "ssh", "user": "admin", "password": ""
}
defaults.update(kwargs)
n = node(defaults.pop("unique"), defaults.pop("host"), **defaults)
return n
def test_run_returns_output(self, mock_pexpect):
"""run() returns string output."""
child = mock_pexpect["child"]
pexp = mock_pexpect["pexpect"]
# Simulate: connect succeeds, command runs, prompt found
child.expect.return_value = 9 # prompt index for ssh
child.logfile_read = None
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
# Mock _connect to return True and set up child
with patch.object(n, '_connect', return_value=True):
n.child = child
log_buffer = io.BytesIO(b"show version\nRouter v1.0\nrouter#")
n.mylog = log_buffer
child.logfile_read = log_buffer
with patch.object(n, '_logclean', return_value="Router v1.0"):
output = n.run(["show version"])
assert n.status == 0
assert output == "Router v1.0"
def test_run_status_1_on_failure(self, mock_pexpect):
"""Status 1 when connection fails."""
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
with patch.object(n, '_connect', return_value="Connection failed code: 1\nrefused"):
output = n.run(["show version"])
assert n.status == 1
assert "refused" in output
def test_run_with_variables(self, mock_pexpect):
"""Variables get substituted in commands."""
child = mock_pexpect["child"]
child.expect.return_value = 9
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
sent_commands = []
child.sendline.side_effect = lambda cmd: sent_commands.append(cmd)
with patch.object(n, '_connect', return_value=True):
n.child = child
n.mylog = io.BytesIO(b"output")
with patch.object(n, '_logclean', return_value="output"):
n.run(["show ip route {subnet}"], vars={"subnet": "10.0.0.0/24"})
assert "show ip route 10.0.0.0/24" in sent_commands
def test_run_saves_to_folder(self, mock_pexpect, tmp_path):
"""folder param saves log file."""
child = mock_pexpect["child"]
child.expect.return_value = 9
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
with patch.object(n, '_connect', return_value=True):
n.child = child
n.mylog = io.BytesIO(b"log output")
with patch.object(n, '_logclean', return_value="log output"):
n.run(["show version"], folder=str(tmp_path))
log_files = list(tmp_path.glob("router1_*.txt"))
assert len(log_files) == 1
assert "log output" in log_files[0].read_text()
class TestNodeTest:
def test_test_returns_dict(self, mock_pexpect):
"""test() returns dict of results."""
child = mock_pexpect["child"]
child.expect.return_value = 0 # prompt found (index 0 in test expects)
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
with patch.object(n, '_connect', return_value=True):
n.child = child
n.mylog = io.BytesIO(b"1.1.1.1 is up")
with patch.object(n, '_logclean', return_value="1.1.1.1 is up"):
result = n.test(["ping 1.1.1.1"], "1.1.1.1")
assert isinstance(result, dict)
assert result.get("1.1.1.1") == True
def test_test_expected_not_found(self, mock_pexpect):
"""Expected text not found returns False."""
child = mock_pexpect["child"]
child.expect.return_value = 0
from connpy.core import node
n = node("router1", "10.0.0.1", user="admin", password="")
with patch.object(n, '_connect', return_value=True):
n.child = child
n.mylog = io.BytesIO(b"some other output")
with patch.object(n, '_logclean', return_value="some other output"):
result = n.test(["ping 1.1.1.1"], "1.1.1.1")
assert isinstance(result, dict)
assert result.get("1.1.1.1") == False
# =========================================================================
# nodes (parallel) tests
# =========================================================================
class TestNodes:
def test_nodes_init(self):
"""Creates list of node objects."""
from connpy.core import nodes
nodes_dict = {
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
"r2": {"host": "10.0.0.2", "user": "admin", "password": ""}
}
mynodes = nodes(nodes_dict)
assert len(mynodes.nodelist) == 2
assert hasattr(mynodes, "r1")
assert hasattr(mynodes, "r2")
def test_nodes_run_parallel(self):
"""run() executes on all nodes and returns dict."""
from connpy.core import nodes
nodes_dict = {
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
"r2": {"host": "10.0.0.2", "user": "admin", "password": ""}
}
mynodes = nodes(nodes_dict)
# Mock run on each node — must set output AND status on the node
for n in mynodes.nodelist:
original_node = n # capture by value
def make_mock(node_ref):
def mock_run(commands, **kwargs):
node_ref.output = f"output from {node_ref.unique}"
node_ref.status = 0
return mock_run
n.run = make_mock(n)
result = mynodes.run(["show version"])
assert "r1" in result
assert "r2" in result
def test_nodes_splitlist(self):
"""_splitlist divides list correctly."""
from connpy.core import nodes
mynodes = nodes({"r1": {"host": "1.1.1.1", "user": "", "password": ""}})
chunks = list(mynodes._splitlist([1, 2, 3, 4, 5], 2))
assert chunks == [[1, 2], [3, 4], [5]]
def test_nodes_run_with_vars(self):
"""Variables per node and __global__ work."""
from connpy.core import nodes
nodes_dict = {
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
}
mynodes = nodes(nodes_dict)
captured_vars = {}
def mock_run(commands, vars=None, **kwargs):
captured_vars.update(vars or {})
mynodes.r1.output = "ok"
mynodes.r1.status = 0
mynodes.r1.run = mock_run
variables = {
"__global__": {"mask": "255.255.255.0"},
"r1": {"ip": "10.0.0.1"}
}
mynodes.run(["show ip"], vars=variables)
assert captured_vars.get("mask") == "255.255.255.0"
assert captured_vars.get("ip") == "10.0.0.1"
def test_nodes_on_complete_callback(self):
"""on_complete callback fires per node."""
from connpy.core import nodes
nodes_dict = {
"r1": {"host": "10.0.0.1", "user": "admin", "password": ""},
}
mynodes = nodes(nodes_dict)
completed = []
def mock_run(commands, **kwargs):
mynodes.r1.output = "done"
mynodes.r1.status = 0
mynodes.r1.run = mock_run
def on_done(unique, output, status):
completed.append(unique)
mynodes.run(["show version"], on_complete=on_done)
assert "r1" in completed
+216
View File
@@ -0,0 +1,216 @@
"""Tests for connpy.hooks module — MethodHook and ClassHook."""
import pytest
from connpy.hooks import MethodHook, ClassHook
# =========================================================================
# MethodHook Tests
# =========================================================================
class TestMethodHook:
def test_basic_call(self):
"""Decorated function executes normally."""
@MethodHook
def add(a, b):
return a + b
assert add(2, 3) == 5
def test_pre_hook_modifies_args(self):
"""Pre-hook can modify arguments before execution."""
@MethodHook
def greet(name):
return f"Hello {name}"
def uppercase_hook(name):
return (name.upper(),), {}
greet.register_pre_hook(uppercase_hook)
assert greet("world") == "Hello WORLD"
def test_post_hook_modifies_result(self):
"""Post-hook can modify the return value."""
@MethodHook
def compute(x):
return x * 2
def double_result(*args, **kwargs):
return kwargs["result"] * 2
compute.register_post_hook(double_result)
assert compute(5) == 20 # 5*2=10, then 10*2=20
def test_multiple_pre_hooks_order(self):
"""Pre-hooks execute in registration order."""
calls = []
@MethodHook
def func(x):
return x
def hook1(x):
calls.append("hook1")
return (x,), {}
def hook2(x):
calls.append("hook2")
return (x,), {}
func.register_pre_hook(hook1)
func.register_pre_hook(hook2)
func(1)
assert calls == ["hook1", "hook2"]
def test_multiple_post_hooks_order(self):
"""Post-hooks execute in registration order."""
calls = []
@MethodHook
def func(x):
return x
def hook1(*args, **kwargs):
calls.append("hook1")
return kwargs["result"]
def hook2(*args, **kwargs):
calls.append("hook2")
return kwargs["result"]
func.register_post_hook(hook1)
func.register_post_hook(hook2)
func(1)
assert calls == ["hook1", "hook2"]
def test_pre_hook_exception_continues(self, capsys):
"""If a pre-hook raises, the function still executes."""
@MethodHook
def func(x):
return x + 1
def bad_hook(x):
raise RuntimeError("broken hook")
func.register_pre_hook(bad_hook)
# Should not raise — the hook error is printed but execution continues
result = func(5)
assert result == 6
def test_post_hook_exception_continues(self, capsys):
"""If a post-hook raises, the result is still returned."""
@MethodHook
def func(x):
return x + 1
def bad_hook(*args, **kwargs):
raise RuntimeError("broken post hook")
func.register_post_hook(bad_hook)
result = func(5)
assert result == 6
def test_method_hook_as_instance_method(self):
"""MethodHook works as a descriptor on a class."""
class MyClass:
@MethodHook
def double(self, x):
return x * 2
obj = MyClass()
assert obj.double(5) == 10
def test_method_hook_instance_hook_registration(self):
"""Can register hooks via instance method access."""
class MyClass:
@MethodHook
def process(self, x):
return x
def add_ten(*args, **kwargs):
return kwargs["result"] + 10
obj = MyClass()
obj.process.register_post_hook(add_ten)
assert obj.process(5) == 15
# =========================================================================
# ClassHook Tests
# =========================================================================
class TestClassHook:
def test_creates_instance(self):
"""ClassHook still creates instances normally."""
@ClassHook
class MyClass:
def __init__(self, value):
self.value = value
obj = MyClass(42)
assert obj.value == 42
def test_modify_future_instances(self):
"""modify() affects all future instances."""
@ClassHook
class MyClass:
def __init__(self):
self.x = 1
def set_x_to_99(instance):
instance.x = 99
MyClass.modify(set_x_to_99)
obj = MyClass()
assert obj.x == 99
def test_modify_does_not_affect_past(self):
"""modify() does not affect already-created instances."""
@ClassHook
class MyClass:
def __init__(self):
self.x = 1
old_obj = MyClass()
def set_x_to_99(instance):
instance.x = 99
MyClass.modify(set_x_to_99)
assert old_obj.x == 1 # Not affected
assert MyClass().x == 99 # New instance IS affected
def test_instance_modify(self):
"""instance.modify() only affects that specific instance."""
@ClassHook
class MyClass:
def __init__(self):
self.x = 1
obj1 = MyClass()
obj2 = MyClass()
obj1.modify(lambda inst: setattr(inst, 'x', 999))
assert obj1.x == 999
assert obj2.x == 1
def test_multiple_deferred_hooks(self):
"""Multiple modify() calls apply in order."""
@ClassHook
class MyClass:
def __init__(self):
self.log = []
MyClass.modify(lambda inst: inst.log.append("first"))
MyClass.modify(lambda inst: inst.log.append("second"))
obj = MyClass()
assert obj.log == ["first", "second"]
def test_getattr_delegation(self):
"""ClassHook delegates attribute access to the wrapped class."""
@ClassHook
class MyClass:
class_var = "hello"
def __init__(self):
pass
assert MyClass.class_var == "hello"
+327
View File
@@ -0,0 +1,327 @@
"""Tests for connpy.plugins module."""
import os
import textwrap
import pytest
from connpy.plugins import Plugins
# ---------------------------------------------------------------------------
# Helper: write a plugin script to a file
# ---------------------------------------------------------------------------
def _write_plugin(path, code):
"""Write dedented code to a file."""
with open(path, "w") as f:
f.write(textwrap.dedent(code))
# =========================================================================
# verify_script tests
# =========================================================================
class TestVerifyScript:
def test_valid_parser_entrypoint(self, tmp_path):
p = tmp_path / "good.py"
_write_plugin(p, """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser()
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
""")
plugins = Plugins()
assert plugins.verify_script(str(p)) == False
def test_valid_preload_only(self, tmp_path):
p = tmp_path / "preload.py"
_write_plugin(p, """\
class Preload:
def __init__(self, connapp):
pass
""")
plugins = Plugins()
assert plugins.verify_script(str(p)) == False
def test_valid_all_three(self, tmp_path):
p = tmp_path / "all.py"
_write_plugin(p, """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser()
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
class Preload:
def __init__(self, connapp):
pass
""")
plugins = Plugins()
assert plugins.verify_script(str(p)) == False
def test_parser_without_entrypoint(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser()
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result # Should be a truthy error string
assert "Entrypoint" in result
def test_entrypoint_without_parser(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "Parser" in result
def test_no_valid_class(self, tmp_path):
p = tmp_path / "empty.py"
_write_plugin(p, """\
def some_function():
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "No valid class" in result
def test_parser_missing_self_parser(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
class Parser:
def __init__(self):
self.something = "not parser"
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "self.parser" in result
def test_entrypoint_wrong_args(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser()
class Entrypoint:
def __init__(self, args):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "Entrypoint" in result
def test_preload_wrong_args(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
class Preload:
def __init__(self, connapp, extra):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "Preload" in result
def test_disallowed_top_level(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
MY_GLOBAL = "not allowed"
class Preload:
def __init__(self, connapp):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "not allowed" in result.lower() or "Plugin can only have" in result
def test_syntax_error(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
def broken(
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "Syntax error" in result
def test_if_name_main_allowed(self, tmp_path):
p = tmp_path / "good.py"
_write_plugin(p, """\
class Preload:
def __init__(self, connapp):
pass
if __name__ == "__main__":
print("standalone")
""")
plugins = Plugins()
assert plugins.verify_script(str(p)) == False
def test_other_if_not_allowed(self, tmp_path):
p = tmp_path / "bad.py"
_write_plugin(p, """\
import sys
if sys.platform == "linux":
pass
class Preload:
def __init__(self, connapp):
pass
""")
plugins = Plugins()
result = plugins.verify_script(str(p))
assert result
assert "__name__" in result
# =========================================================================
# Import and loading tests
# =========================================================================
class TestPluginLoading:
def test_import_from_path(self, tmp_path):
p = tmp_path / "mymod.py"
_write_plugin(p, """\
MY_VAR = 42
""")
plugins = Plugins()
module = plugins._import_from_path(str(p))
assert module.MY_VAR == 42
def test_import_plugins_to_argparse(self, tmp_path):
"""Valid plugins get loaded into argparse."""
import argparse
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
_write_plugin(plugin_dir / "myplugin.py", """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser(description="My plugin")
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
""")
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
plugins = Plugins()
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
assert "myplugin" in plugins.plugins
assert "myplugin" in plugins.plugin_parsers
def test_plugin_name_collision(self, tmp_path):
"""Plugin with same name as existing subcommand is skipped."""
import argparse
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
_write_plugin(plugin_dir / "existcmd.py", """\
import argparse
class Parser:
def __init__(self):
self.parser = argparse.ArgumentParser()
class Entrypoint:
def __init__(self, args, parser, connapp):
pass
""")
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
subparsers.add_parser("existcmd") # Already taken
plugins = Plugins()
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
assert "existcmd" not in plugins.plugins
def test_preload_registration(self, tmp_path):
"""Preload class gets registered in preloads dict."""
import argparse
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
_write_plugin(plugin_dir / "preloader.py", """\
class Preload:
def __init__(self, connapp):
pass
""")
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
plugins = Plugins()
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
assert "preloader" in plugins.preloads
def test_invalid_plugin_skipped(self, tmp_path, capsys):
"""Invalid plugin is skipped with error message."""
import argparse
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
_write_plugin(plugin_dir / "badplugin.py", """\
MY_GLOBAL = "bad"
""")
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
plugins = Plugins()
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
assert "badplugin" not in plugins.plugins
captured = capsys.readouterr()
assert "Failed to load plugin" in captured.err or "Failed to load plugin" in captured.out
def test_empty_directory(self, tmp_path):
"""Empty directory doesn't cause errors."""
import argparse
plugin_dir = tmp_path / "plugins"
plugin_dir.mkdir()
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
plugins = Plugins()
plugins._import_plugins_to_argparse(str(plugin_dir), subparsers)
assert len(plugins.plugins) == 0
+50
View File
@@ -0,0 +1,50 @@
"""Tests for connpy.printer module."""
import sys
from io import StringIO
from connpy import printer
class TestPrinter:
def test_info_output(self, capsys):
printer.info("hello world")
captured = capsys.readouterr()
assert "[i] hello world" in captured.out
def test_success_output(self, capsys):
printer.success("done")
captured = capsys.readouterr()
assert "[✓] done" in captured.out
def test_warning_output(self, capsys):
printer.warning("careful")
captured = capsys.readouterr()
assert "[!] careful" in captured.out
def test_error_output(self, capsys):
printer.error("failed")
captured = capsys.readouterr()
assert "[✗] failed" in captured.err
def test_debug_output(self, capsys):
printer.debug("debug info")
captured = capsys.readouterr()
assert "[d] debug info" in captured.out
def test_start_output(self, capsys):
printer.start("starting")
captured = capsys.readouterr()
assert "[+] starting" in captured.out
def test_custom_output(self, capsys):
printer.custom("TAG", "custom message")
captured = capsys.readouterr()
assert "[TAG] custom message" in captured.out
def test_multiline_indentation(self, capsys):
printer.info("line1\nline2\nline3")
captured = capsys.readouterr()
lines = captured.out.strip().split("\n")
assert lines[0] == "[i] line1"
# Second line should be indented by len("[i] ") = 4 chars
assert lines[1].startswith(" line2")
assert lines[2].startswith(" line3")