diff --git a/connpy/__init__.py b/connpy/__init__.py index d1b297f..ee992ac 100644 --- a/connpy/__init__.py +++ b/connpy/__init__.py @@ -542,7 +542,6 @@ from .api import * from .ai import ai from .plugins import Plugins from ._version import __version__ -from pkg_resources import get_distribution from . import printer __all__ = ["node", "nodes", "configfile", "connapp", "ai", "Plugins", "printer"] diff --git a/connpy/ai.py b/connpy/ai.py index 91aee29..8854ddc 100755 --- a/connpy/ai.py +++ b/connpy/ai.py @@ -396,7 +396,7 @@ class ai: if isinstance(commands, str): try: commands = json.loads(commands) - except: + except ValueError: commands = [c.strip() for c in commands.split('\n') if c.strip()] # Expand multi-line commands within a list (in case the AI packs them) @@ -795,9 +795,10 @@ class ai: response = completion(model=model, messages=safe_messages, tools=[], api_key=key) resp_msg = response.choices[0].message messages.append(resp_msg.model_dump(exclude_none=True)) - except: - pass - + except Exception as e: + if status: + status.update(f"[bold red]Error fetching summary: {e}[/bold red]") + printer.warning(f"Failed to fetch final summary from LLM: {e}") except KeyboardInterrupt: if status: status.update("[bold red]Interrupted! Closing pending tasks...") last_msg = messages[-1] @@ -810,7 +811,7 @@ class ai: response = completion(model=model, messages=safe_messages, tools=tools, api_key=key) resp_msg = response.choices[0].message messages.append(resp_msg.model_dump(exclude_none=True)) - except: pass + except Exception: pass finally: try: log_dir = self.config.defaultdir @@ -820,7 +821,7 @@ class ai: if os.path.exists(log_path): try: with open(log_path, "r") as f: hist = json.load(f) - except: hist = [] + except (IOError, json.JSONDecodeError): hist = [] hist.append({"timestamp": datetime.datetime.now().isoformat(), "roles": {"strategic_engine": self.architect_model, "execution_engine": self.engineer_model}, "session": messages}) with open(log_path, "w") as f: json.dump(hist[-10:], f, indent=4) except Exception as e: diff --git a/connpy/api.py b/connpy/api.py index 8260465..3436439 100755 --- a/connpy/api.py +++ b/connpy/api.py @@ -35,7 +35,7 @@ def list_nodes(): else: filter = filter.lower() output = conf._getallnodes(filter) - except: + except Exception: output = conf._getallnodes() return jsonify(output) @@ -52,7 +52,7 @@ def get_nodes(): else: filter = filter.lower() output = conf._getallnodesfull(filter) - except: + except Exception: output = conf._getallnodesfull() return jsonify(output) @@ -109,13 +109,13 @@ def run_commands(): mynodes = nodes(mynodes, config=conf) try: args["vars"] = data["vars"] - except: + except Exception: pass try: options = data["options"] thisoptions = {k: v for k, v in options.items() if k in ["prompt", "parallel", "timeout"]} args.update(thisoptions) - except: + except Exception: options = None if action == "run": output = mynodes.run(**args) @@ -136,20 +136,20 @@ def stop_api(): pid = int(f.readline().strip()) port = int(f.readline().strip()) PID_FILE=PID_FILE1 - except: + except (FileNotFoundError, ValueError, OSError): try: with open(PID_FILE2, "r") as f: pid = int(f.readline().strip()) port = int(f.readline().strip()) PID_FILE=PID_FILE2 - except: + except (FileNotFoundError, ValueError, OSError): printer.warning("Connpy API server is not running.") return # Send a SIGTERM signal to the process try: os.kill(pid, signal.SIGTERM) - except: - pass + except OSError as e: + printer.warning(f"Process kill failed (maybe already dead): {e}") # Delete the PID file os.remove(PID_FILE) printer.info(f"Server with process ID {pid} stopped.") @@ -177,11 +177,11 @@ def start_api(port=8048): try: with open(PID_FILE1, "w") as f: f.write(str(pid) + "\n" + str(port)) - except: + except OSError: try: with open(PID_FILE2, "w") as f: f.write(str(pid) + "\n" + str(port)) - except: + except OSError: printer.error("Couldn't create PID file.") exit(1) printer.start(f"Server is running with process ID {pid} on port {port}") diff --git a/connpy/completion.py b/connpy/completion.py index 8303ea1..448fcd5 100755 --- a/connpy/completion.py +++ b/connpy/completion.py @@ -95,7 +95,7 @@ def main(): try: with open(pathfile, "r") as f: configdir = f.read().strip() - except: + except (FileNotFoundError, IOError): configdir = defaultdir defaultfile = configdir + '/config.json' jsonconf = open(defaultfile) @@ -131,7 +131,7 @@ def main(): spec.loader.exec_module(module) plugin_completion = getattr(module, "_connpy_completion") strings = plugin_completion(wordsnumber, words, info) - except: + except Exception: exit() elif wordsnumber >= 3 and words[0] == "ai": if wordsnumber == 3: diff --git a/connpy/configfile.py b/connpy/configfile.py index cf82272..5a4a0e0 100755 --- a/connpy/configfile.py +++ b/connpy/configfile.py @@ -8,6 +8,7 @@ from Crypto.Cipher import PKCS1_OAEP from pathlib import Path from copy import deepcopy from .hooks import MethodHook, ClassHook +from . import printer @@ -60,7 +61,7 @@ class configfile: try: with open(pathfile, "r") as f: configdir = f.read().strip() - except: + except (FileNotFoundError, IOError): with open(pathfile, "w") as f: f.write(str(defaultdir)) configdir = defaultdir @@ -120,7 +121,8 @@ class configfile: with open(conf, "w") as f: json.dump(newconfig, f, indent = 4) f.close() - except: + except (IOError, OSError) as e: + printer.error(f"Failed to save config: {e}") return 1 return 0 @@ -205,12 +207,12 @@ class configfile: if profile: try: newfolder[node_name][key] = self.profiles[profile.group(1)][key] - except: + except KeyError: newfolder[node_name][key] = "" elif value == '' and key == "protocol": try: newfolder[node_name][key] = self.profiles["default"][key] - except: + except KeyError: newfolder[node_name][key] = "ssh" newfolder = {"{}{}".format(k,unique):v for k,v in newfolder.items()} @@ -231,12 +233,12 @@ class configfile: if profile: try: newnode[key] = self.profiles[profile.group(1)][key] - except: + except KeyError: newnode[key] = "" elif value == '' and key == "protocol": try: newnode[key] = self.profiles["default"][key] - except: + except KeyError: newnode[key] = "ssh" return newnode @@ -391,12 +393,12 @@ class configfile: if profile: try: nodes[node][key] = self.profiles[profile.group(1)][key] - except: + except KeyError: nodes[node][key] = "" elif value == '' and key == "protocol": try: nodes[node][key] = self.profiles["default"][key] - except: + except KeyError: nodes[node][key] = "ssh" return nodes diff --git a/connpy/connapp.py b/connpy/connapp.py index d208069..de97bc8 100755 --- a/connpy/connapp.py +++ b/connpy/connapp.py @@ -17,7 +17,6 @@ import shutil class NoAliasDumper(yaml.SafeDumper): def ignore_aliases(self, data): return True -import ast from rich.markdown import Markdown from rich.console import Console, Group from rich.panel import Panel @@ -29,7 +28,7 @@ mdprint = Console().print console = Console() try: from pyfzf.pyfzf import FzfPrompt -except: +except ImportError: FzfPrompt = None @@ -64,7 +63,7 @@ class connapp: self.case = self.config.config["case"] try: self.fzf = self.config.config["fzf"] - except: + except KeyError: self.fzf = False @@ -178,13 +177,13 @@ class connapp: try: core_path = os.path.dirname(os.path.realpath(__file__)) + "/core_plugins" self.plugins._import_plugins_to_argparse(core_path, subparsers) - except: - pass + except Exception as e: + printer.warning(e) try: file_path = self.config.defaultdir + "/plugins" self.plugins._import_plugins_to_argparse(file_path, subparsers) - except: - pass + except Exception as e: + printer.warning(e) for preload in self.plugins.preloads.values(): preload.Preload(self) #Generate helps @@ -826,7 +825,7 @@ class connapp: try: with open(args.data[0]) as file: imported = yaml.load(file, Loader=yaml.FullLoader) - except: + except Exception: printer.error("failed reading file {}".format(args.data[0])) exit(10) for k,v in imported.items(): @@ -1013,7 +1012,7 @@ class connapp: try: with open(args.data[0]) as file: scripts = yaml.load(file, Loader=yaml.FullLoader) - except: + except Exception: printer.error("failed reading file {}".format(args.data[0])) exit(10) for script in scripts["tasks"]: @@ -1053,13 +1052,13 @@ class connapp: options = script["options"] thisoptions = {k: v for k, v in options.items() if k in ["prompt", "parallel", "timeout"]} args.update(thisoptions) - except: + except KeyError: options = None try: size = str(os.get_terminal_size()) p = re.search(r'.*columns=([0-9]+)', size) columns = int(p.group(1)) - except: + except (ValueError, OSError): columns = 80 PANEL_WIDTH = columns @@ -1182,7 +1181,7 @@ class connapp: raise inquirer.errors.ValidationError("", reason="Pick a port between 1-65535, @profile o leave empty") try: port = int(current) - except: + except ValueError: port = 0 if current != "" and not 1 <= int(port) <= 65535: raise inquirer.errors.ValidationError("", reason="Pick a port between 1-65535 or leave empty") @@ -1194,7 +1193,7 @@ class connapp: raise inquirer.errors.ValidationError("", reason="Pick a port between 1-6553/app5, @profile or leave empty") try: port = int(current) - except: + except ValueError: port = 0 if current.startswith("@"): if current[1:] not in self.profiles: @@ -1220,7 +1219,7 @@ class connapp: isdict = False try: isdict = ast.literal_eval(current) - except: + except Exception: pass if not isinstance (isdict, dict): raise inquirer.errors.ValidationError("", reason="Tags should be a python dictionary.".format(current)) @@ -1232,7 +1231,7 @@ class connapp: isdict = False try: isdict = ast.literal_eval(current) - except: + except Exception: pass if not isinstance (isdict, dict): raise inquirer.errors.ValidationError("", reason="Tags should be a python dictionary.".format(current)) @@ -1316,7 +1315,7 @@ class connapp: defaults["tags"] = "" if "jumphost" not in defaults: defaults["jumphost"] = "" - except: + except KeyError: defaults = { "host":"", "protocol":"", "port":"", "user":"", "options":"", "logs":"" , "tags":"", "password":"", "jumphost":""} node = {} if edit == None: @@ -1390,7 +1389,7 @@ class connapp: defaults["tags"] = "" if "jumphost" not in defaults: defaults["jumphost"] = "" - except: + except KeyError: defaults = { "host":"", "protocol":"", "port":"", "user":"", "options":"", "logs":"", "tags": "", "jumphost": ""} profile = {} if edit == None: diff --git a/connpy/core.py b/connpy/core.py index 4ad32cd..f93c2ba 100755 --- a/connpy/core.py +++ b/connpy/core.py @@ -84,12 +84,12 @@ class node: if profile and config != '': try: setattr(self,key,config.profiles[profile.group(1)][key]) - except: + except KeyError: setattr(self,key,"") elif attr[key] == '' and key == "protocol": try: setattr(self,key,config.profiles["default"][key]) - except: + except (KeyError, AttributeError): setattr(self,key,"ssh") else: setattr(self,key,attr[key]) @@ -108,12 +108,12 @@ class node: if profile: try: self.jumphost[key] = config.profiles[profile.group(1)][key] - except: + except KeyError: self.jumphost[key] = "" elif self.jumphost[key] == '' and key == "protocol": try: self.jumphost[key] = config.profiles["default"][key] - except: + except KeyError: self.jumphost[key] = "ssh" if isinstance(self.jumphost["password"],list): jumphost_password = [] @@ -158,7 +158,7 @@ class node: try: decrypted = decryptor.decrypt(ast.literal_eval(passwd)).decode("utf-8") dpass.append(decrypted) - except: + except Exception: raise ValueError("Missing or corrupted key") return dpass diff --git a/connpy/core_plugins/capture.py b/connpy/core_plugins/capture.py index d001293..72a4aa0 100644 --- a/connpy/core_plugins/capture.py +++ b/connpy/core_plugins/capture.py @@ -176,7 +176,7 @@ class RemoteCapture: printer.success("Tcpdump finished capturing packets.") self.listener_active = False - except: + except Exception: pass def _sendline_until_connected(self, cmd, retries=5, interval=2): @@ -307,7 +307,7 @@ class RemoteCapture: try: self.fake_connection = True socket.create_connection(("localhost", self.local_port), timeout=1).close() - except: + except OSError: pass self.listener_active = False return @@ -324,7 +324,7 @@ class RemoteCapture: try: self.listener_conn.shutdown(socket.SHUT_RDWR) self.listener_conn.close() - except: + except OSError: pass if hasattr(self.node, "child"): self.node.child.close(force=True) diff --git a/connpy/core_plugins/sync.py b/connpy/core_plugins/sync.py index ccfcf5b..92eb040 100755 --- a/connpy/core_plugins/sync.py +++ b/connpy/core_plugins/sync.py @@ -28,7 +28,7 @@ class sync: self.connapp = connapp try: self.sync = self.connapp.config.config["sync"] - except: + except KeyError: self.sync = False def login(self): @@ -322,7 +322,7 @@ class sync: def config_listener_pre(self, *args, **kwargs): try: self.sync = self.connapp.config.config["sync"] - except: + except KeyError: self.sync = False return args, kwargs diff --git a/connpy/plugins.py b/connpy/plugins.py index 18d7bbf..81a95a2 100755 --- a/connpy/plugins.py +++ b/connpy/plugins.py @@ -62,8 +62,8 @@ class Plugins: if not (isinstance(node.test, ast.Compare) and isinstance(node.test.left, ast.Name) and node.test.left.id == '__name__' and - isinstance(node.test.comparators[0], ast.Str) and - node.test.comparators[0].s == '__main__'): + ((hasattr(ast, 'Str') and isinstance(node.test.comparators[0], getattr(ast, 'Str')) and node.test.comparators[0].s == '__main__') or + (hasattr(ast, 'Constant') and isinstance(node.test.comparators[0], getattr(ast, 'Constant')) and node.test.comparators[0].value == '__main__'))): return "Only __name__ == __main__ If is allowed" elif not isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Import, ast.ImportFrom, ast.Pass)): diff --git a/connpy/tests/__init__.py b/connpy/tests/__init__.py new file mode 100644 index 0000000..d4839a6 --- /dev/null +++ b/connpy/tests/__init__.py @@ -0,0 +1 @@ +# Tests package diff --git a/connpy/tests/conftest.py b/connpy/tests/conftest.py new file mode 100644 index 0000000..16dbaab --- /dev/null +++ b/connpy/tests/conftest.py @@ -0,0 +1,192 @@ +"""Shared fixtures for connpy tests. + +All tests use tmp_path to create isolated config/keys. +No test touches ~/.config/conn/ +""" +import pytest +import json +import os +from unittest.mock import patch, MagicMock +from Crypto.PublicKey import RSA + + +# --------------------------------------------------------------------------- +# Minimal config data +# --------------------------------------------------------------------------- +DEFAULT_CONFIG = { + "config": {"case": False, "idletime": 30, "fzf": False}, + "connections": {}, + "profiles": { + "default": { + "host": "", "protocol": "ssh", "port": "", "user": "", + "password": "", "options": "", "logs": "", "tags": "", "jumphost": "" + } + } +} + +SAMPLE_CONNECTIONS = { + "router1": { + "host": "10.0.0.1", "protocol": "ssh", "port": "22", + "user": "admin", "password": "pass1", "options": "", + "logs": "", "tags": "", "jumphost": "", "type": "connection" + }, + "office": { + "type": "folder", + "server1": { + "host": "10.0.1.1", "protocol": "ssh", "port": "", + "user": "root", "password": "pass2", "options": "", + "logs": "", "tags": "", "jumphost": "", "type": "connection" + }, + "datacenter": { + "type": "subfolder", + "db1": { + "host": "10.0.2.1", "protocol": "ssh", "port": "", + "user": "dbadmin", "password": "pass3", "options": "", + "logs": "", "tags": "", "jumphost": "", "type": "connection" + } + } + } +} + +SAMPLE_PROFILES = { + "default": { + "host": "", "protocol": "ssh", "port": "", "user": "", + "password": "", "options": "", "logs": "", "tags": "", "jumphost": "" + }, + "office-user": { + "host": "", "protocol": "ssh", "port": "", "user": "officeadmin", + "password": "officepass", "options": "", "logs": "", "tags": "", "jumphost": "" + } +} + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def tmp_config_dir(tmp_path): + """Create an isolated config directory with config.json and RSA key.""" + config_dir = tmp_path / ".config" / "conn" + config_dir.mkdir(parents=True) + plugins_dir = config_dir / "plugins" + plugins_dir.mkdir() + + # Write config.json + config_file = config_dir / "config.json" + config_file.write_text(json.dumps(DEFAULT_CONFIG, indent=4)) + os.chmod(str(config_file), 0o600) + + # Write .folder (points to itself) + folder_file = config_dir / ".folder" + folder_file.write_text(str(config_dir)) + + # Generate RSA key + key = RSA.generate(2048) + key_file = config_dir / ".osk" + key_file.write_bytes(key.export_key("PEM")) + os.chmod(str(key_file), 0o600) + + return config_dir + + +@pytest.fixture +def config(tmp_config_dir): + """Create a configfile instance pointing to tmp directory.""" + from connpy.configfile import configfile + conf_path = str(tmp_config_dir / "config.json") + key_path = str(tmp_config_dir / ".osk") + return configfile(conf=conf_path, key=key_path) + + +@pytest.fixture +def populated_config(tmp_config_dir): + """Create a configfile with sample nodes/profiles pre-loaded.""" + config_file = tmp_config_dir / "config.json" + data = { + "config": {"case": False, "idletime": 30, "fzf": False}, + "connections": SAMPLE_CONNECTIONS, + "profiles": SAMPLE_PROFILES + } + config_file.write_text(json.dumps(data, indent=4)) + from connpy.configfile import configfile + return configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk")) + + +@pytest.fixture +def mock_pexpect(): + """Mock pexpect.spawn for connection tests.""" + with patch("connpy.core.pexpect") as mock_pexp: + child = MagicMock() + child.before = b"" + child.after = b"router#" + child.readline.return_value = b"" + child.child_fd = 3 + mock_pexp.spawn.return_value = child + mock_pexp.EOF = object() + mock_pexp.TIMEOUT = object() + + # Also mock fdpexpect + with patch("connpy.core.fdpexpect", create=True) as mock_fd: + mock_fd.fdspawn.return_value = MagicMock() + yield { + "pexpect": mock_pexp, + "child": child, + "fdpexpect": mock_fd + } + + +@pytest.fixture +def mock_litellm(): + """Mock litellm.completion for AI tests.""" + with patch("connpy.ai.completion") as mock_comp: + # Create a default response + msg = MagicMock() + msg.content = "Test response from AI" + msg.tool_calls = None + msg.role = "assistant" + msg.model_dump.return_value = { + "role": "assistant", + "content": "Test response from AI" + } + + choice = MagicMock() + choice.message = msg + + response = MagicMock() + response.choices = [choice] + response.usage = MagicMock() + response.usage.prompt_tokens = 100 + response.usage.completion_tokens = 50 + response.usage.total_tokens = 150 + + mock_comp.return_value = response + + yield { + "completion": mock_comp, + "response": response, + "message": msg, + "choice": choice + } + + +@pytest.fixture +def ai_config(tmp_config_dir): + """Create a configfile with AI keys configured for AI tests.""" + config_file = tmp_config_dir / "config.json" + data = { + "config": { + "case": False, "idletime": 30, "fzf": False, + "ai": { + "engineer_model": "test/test-model", + "engineer_api_key": "test-engineer-key", + "architect_model": "test/test-architect", + "architect_api_key": "test-architect-key" + } + }, + "connections": SAMPLE_CONNECTIONS, + "profiles": SAMPLE_PROFILES + } + config_file.write_text(json.dumps(data, indent=4)) + from connpy.configfile import configfile + return configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk")) diff --git a/connpy/tests/test_ai.py b/connpy/tests/test_ai.py new file mode 100644 index 0000000..17e39e0 --- /dev/null +++ b/connpy/tests/test_ai.py @@ -0,0 +1,397 @@ +"""Tests for connpy.ai module.""" +import json +import os +import pytest +from unittest.mock import patch, MagicMock + + +# ========================================================================= +# AI Init tests +# ========================================================================= + +class TestAIInit: + def test_init_with_keys(self, ai_config, mock_litellm): + """Initializes correctly when keys are configured.""" + from connpy.ai import ai + myai = ai(ai_config) + assert myai.engineer_model == "test/test-model" + assert myai.architect_model == "test/test-architect" + + def test_init_missing_engineer_key(self, config): + """Raises ValueError if engineer key is missing.""" + from connpy.ai import ai + with pytest.raises(ValueError, match="Engineer API key"): + ai(config) + + def test_init_missing_architect_key_warns(self, ai_config, capsys, mock_litellm): + """Warns if architect key is missing but doesn't crash.""" + # Remove architect key + ai_config.config["ai"]["architect_api_key"] = None + from connpy.ai import ai + # Should not raise + myai = ai(ai_config) + assert myai.architect_key is None + + def test_default_models(self, config): + """Default models are set correctly when not configured.""" + config.config["ai"] = {"engineer_api_key": "test-key", "architect_api_key": "test-key"} + from connpy.ai import ai + myai = ai(config) + assert "gemini" in myai.engineer_model.lower() + assert "claude" in myai.architect_model.lower() or "anthropic" in myai.architect_model.lower() + + def test_init_loads_memory(self, ai_config, tmp_path, mock_litellm): + """Loads long-term memory from file if it exists.""" + memory_path = os.path.expanduser("~/.config/conn/ai_memory.md") + from connpy.ai import ai + + with patch("os.path.exists", side_effect=lambda p: True if p == memory_path else os.path.exists(p)): + with patch("builtins.open", side_effect=lambda f, *a, **kw: ( + __import__("io").StringIO("## Memory\nRouter1 is border router") + if f == memory_path else open(f, *a, **kw) + )): + try: + myai = ai(ai_config) + except Exception: + pass # May fail on other file opens, that's ok + + +# ========================================================================= +# register_ai_tool tests +# ========================================================================= + +class TestRegisterAITool: + @pytest.fixture + def myai(self, ai_config, mock_litellm): + from connpy.ai import ai + return ai(ai_config) + + def _make_tool_def(self, name="my_tool"): + return { + "type": "function", + "function": { + "name": name, + "description": "Test tool", + "parameters": {"type": "object", "properties": {}} + } + } + + def test_register_tool_engineer(self, myai): + tool_def = self._make_tool_def() + myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="engineer") + assert len(myai.external_engineer_tools) == 1 + assert len(myai.external_architect_tools) == 0 + + def test_register_tool_architect(self, myai): + tool_def = self._make_tool_def() + myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="architect") + assert len(myai.external_architect_tools) == 1 + assert len(myai.external_engineer_tools) == 0 + + def test_register_tool_both(self, myai): + tool_def = self._make_tool_def() + myai.register_ai_tool(tool_def, lambda self, **kw: "ok", target="both") + assert len(myai.external_engineer_tools) == 1 + assert len(myai.external_architect_tools) == 1 + + def test_register_tool_handler(self, myai): + tool_def = self._make_tool_def("custom_tool") + handler = lambda self, **kw: "result" + myai.register_ai_tool(tool_def, handler) + assert "custom_tool" in myai.external_tool_handlers + assert myai.external_tool_handlers["custom_tool"] is handler + + def test_register_tool_prompt_extension(self, myai): + tool_def = self._make_tool_def() + myai.register_ai_tool( + tool_def, lambda self, **kw: "ok", + engineer_prompt="- Custom capability", + architect_prompt=" * Custom tool" + ) + assert any("Custom capability" in ext for ext in myai.engineer_prompt_extensions) + assert any("Custom tool" in ext for ext in myai.architect_prompt_extensions) + + def test_register_tool_status_formatter(self, myai): + tool_def = self._make_tool_def("status_tool") + formatter = lambda args: f"[STATUS] {args}" + myai.register_ai_tool(tool_def, lambda self, **kw: "ok", status_formatter=formatter) + assert "status_tool" in myai.tool_status_formatters + + +# ========================================================================= +# Dynamic prompts tests +# ========================================================================= + +class TestDynamicPrompts: + @pytest.fixture + def myai(self, ai_config, mock_litellm): + from connpy.ai import ai + return ai(ai_config) + + def test_engineer_prompt_without_extensions(self, myai): + prompt = myai.engineer_system_prompt + assert "Plugin Capabilities" not in prompt + assert "TECHNICAL EXECUTION ENGINE" in prompt + + def test_engineer_prompt_with_extensions(self, myai): + myai.engineer_prompt_extensions.append("- AWS Cloud Auditing") + prompt = myai.engineer_system_prompt + assert "Plugin Capabilities" in prompt + assert "AWS Cloud Auditing" in prompt + + def test_architect_prompt_without_extensions(self, myai): + prompt = myai.architect_system_prompt + assert "Plugin Capabilities" not in prompt + assert "STRATEGIC REASONING ENGINE" in prompt + + def test_architect_prompt_with_extensions(self, myai): + myai.architect_prompt_extensions.append(" * Custom tool available") + prompt = myai.architect_system_prompt + assert "Plugin Capabilities" in prompt + assert "Custom tool available" in prompt + + + +# ========================================================================= +# _sanitize_messages tests +# ========================================================================= + +class TestSanitizeMessages: + @pytest.fixture + def myai(self, ai_config, mock_litellm): + from connpy.ai import ai + return ai(ai_config) + + def test_sanitize_empty(self, myai): + assert myai._sanitize_messages([]) == [] + + def test_sanitize_normal_messages(self, myai): + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"} + ] + result = myai._sanitize_messages(messages) + assert len(result) == 3 + + def test_sanitize_removes_orphan_tool_calls(self, myai): + """Tool calls at the end without responses are removed.""" + messages = [ + {"role": "user", "content": "do something"}, + {"role": "assistant", "content": None, "tool_calls": [ + {"id": "tc1", "function": {"name": "list_nodes", "arguments": "{}"}} + ]} + # No tool response follows! + ] + result = myai._sanitize_messages(messages) + assert len(result) == 1 # Only user message + assert result[0]["role"] == "user" + + def test_sanitize_removes_orphan_tool_responses(self, myai): + """Tool responses without preceding tool_calls are removed.""" + messages = [ + {"role": "user", "content": "hello"}, + {"role": "tool", "tool_call_id": "tc1", "name": "list_nodes", "content": "[]"} + ] + result = myai._sanitize_messages(messages) + assert len(result) == 1 + assert result[0]["role"] == "user" + + def test_sanitize_preserves_valid_tool_pairs(self, myai): + """Valid assistant+tool_calls followed by tool responses are preserved.""" + messages = [ + {"role": "user", "content": "list nodes"}, + {"role": "assistant", "content": None, "tool_calls": [ + {"id": "tc1", "function": {"name": "list_nodes", "arguments": "{}"}} + ]}, + {"role": "tool", "tool_call_id": "tc1", "name": "list_nodes", "content": "[\"r1\"]"}, + {"role": "assistant", "content": "Found r1"} + ] + result = myai._sanitize_messages(messages) + assert len(result) == 4 + + +# ========================================================================= +# _truncate tests +# ========================================================================= + +class TestTruncate: + @pytest.fixture + def myai(self, ai_config, mock_litellm): + from connpy.ai import ai + return ai(ai_config) + + def test_truncate_short_text(self, myai): + text = "short text" + assert myai._truncate(text) == text + + def test_truncate_long_text(self, myai): + text = "x" * 100000 + result = myai._truncate(text) + assert len(result) < 100000 + assert "[... OUTPUT TRUNCATED ...]" in result + + def test_truncate_custom_limit(self, myai): + text = "x" * 1000 + result = myai._truncate(text, limit=500) + assert len(result) < 1000 + assert "[... OUTPUT TRUNCATED ...]" in result + + def test_truncate_preserves_head_and_tail(self, myai): + text = "HEAD" + "x" * 100000 + "TAIL" + result = myai._truncate(text) + assert result.startswith("HEAD") + assert result.endswith("TAIL") + + +# ========================================================================= +# Tool methods tests +# ========================================================================= + +class TestToolMethods: + @pytest.fixture + def myai(self, ai_config, mock_litellm): + from connpy.ai import ai + return ai(ai_config) + + def test_list_nodes_tool_found(self, myai): + result = myai.list_nodes_tool("router.*") + parsed = json.loads(result) + assert "router1" in str(parsed) + + def test_list_nodes_tool_not_found(self, myai): + result = myai.list_nodes_tool("nonexistent_pattern_xyz") + assert "No nodes found" in result + + def test_get_node_info_masks_password(self, myai): + result = myai.get_node_info_tool("router1") + parsed = json.loads(result) + assert parsed["password"] == "***" + + def test_is_safe_command_show(self, myai): + assert myai._is_safe_command("show running-config") == True + assert myai._is_safe_command("show ip int brief") == True + + def test_is_safe_command_config(self, myai): + assert myai._is_safe_command("config t") == False + assert myai._is_safe_command("write memory") == False + + def test_is_safe_command_ls(self, myai): + assert myai._is_safe_command("ls -la") == True + + def test_is_safe_command_ping(self, myai): + assert myai._is_safe_command("ping 10.0.0.1") == True + + +# ========================================================================= +# manage_memory_tool tests +# ========================================================================= + +class TestManageMemory: + @pytest.fixture + def myai(self, ai_config, mock_litellm, tmp_path): + from connpy.ai import ai + myai = ai(ai_config) + myai.memory_path = str(tmp_path / "ai_memory.md") + return myai + + def test_manage_memory_append(self, myai): + result = myai.manage_memory_tool("Router1 is border router", action="append") + assert "successfully" in result.lower() + assert os.path.exists(myai.memory_path) + content = open(myai.memory_path).read() + assert "Router1 is border router" in content + + def test_manage_memory_replace(self, myai): + myai.manage_memory_tool("old content", action="append") + myai.manage_memory_tool("new content only", action="replace") + content = open(myai.memory_path).read() + assert "new content only" in content + assert "old content" not in content + + def test_manage_memory_empty_content(self, myai): + result = myai.manage_memory_tool("", action="append") + assert "error" in result.lower() or "Error" in result + + +# ========================================================================= +# ask() with mock LLM tests +# ========================================================================= + +class TestAsk: + @pytest.fixture + def myai(self, ai_config, mock_litellm): + from connpy.ai import ai + return ai(ai_config) + + def test_ask_basic_response(self, myai, mock_litellm): + result = myai.ask("hello", stream=False) + assert "response" in result + assert "chat_history" in result + assert "usage" in result + assert result["response"] == "Test response from AI" + + def test_ask_sticky_brain_engineer(self, myai, mock_litellm): + result = myai.ask("show me the routers", stream=False) + assert result["responder"] == "engineer" + + def test_ask_explicit_architect(self, myai, mock_litellm): + result = myai.ask("architect: review the network design", stream=False) + assert result["responder"] == "architect" + + def test_ask_returns_usage(self, myai, mock_litellm): + result = myai.ask("test", stream=False) + assert result["usage"]["total"] > 0 + + def test_ask_with_chat_history(self, myai, mock_litellm): + history = [ + {"role": "user", "content": "previous question"}, + {"role": "assistant", "content": "previous answer"} + ] + result = myai.ask("follow up", chat_history=history, stream=False) + assert result["response"] is not None + + +# ========================================================================= +# _get_engineer_tools / _get_architect_tools tests +# ========================================================================= + +class TestToolDefinitions: + @pytest.fixture + def myai(self, ai_config, mock_litellm): + from connpy.ai import ai + return ai(ai_config) + + def test_engineer_tools_include_core(self, myai): + tools = myai._get_engineer_tools() + names = [t["function"]["name"] for t in tools] + assert "list_nodes" in names + assert "run_commands" in names + assert "get_node_info" in names + assert "consult_architect" in names + assert "escalate_to_architect" in names + + def test_engineer_tools_include_external(self, myai): + myai.external_engineer_tools.append({ + "type": "function", + "function": {"name": "custom_tool", "description": "test", "parameters": {}} + }) + tools = myai._get_engineer_tools() + names = [t["function"]["name"] for t in tools] + assert "custom_tool" in names + + def test_architect_tools_include_core(self, myai): + tools = myai._get_architect_tools() + names = [t["function"]["name"] for t in tools] + assert "delegate_to_engineer" in names + assert "return_to_engineer" in names + assert "manage_memory_tool" in names + + def test_architect_tools_include_external(self, myai): + myai.external_architect_tools.append({ + "type": "function", + "function": {"name": "arch_tool", "description": "test", "parameters": {}} + }) + tools = myai._get_architect_tools() + names = [t["function"]["name"] for t in tools] + assert "arch_tool" in names diff --git a/connpy/tests/test_api.py b/connpy/tests/test_api.py new file mode 100644 index 0000000..0eebb45 --- /dev/null +++ b/connpy/tests/test_api.py @@ -0,0 +1,268 @@ +"""Tests for connpy.api module — Flask routes.""" +import json +import pytest +from unittest.mock import patch, MagicMock + + +@pytest.fixture +def api_client(populated_config): + """Create a Flask test client with a populated config.""" + from connpy.api import app + app.custom_config = populated_config + app.config["TESTING"] = True + with app.test_client() as client: + yield client + + +# ========================================================================= +# Root endpoint +# ========================================================================= + +class TestRootEndpoint: + def test_root_returns_welcome(self, api_client): + response = api_client.get("/") + data = response.get_json() + assert response.status_code == 200 + assert "Welcome" in data["message"] + assert "version" in data + + +# ========================================================================= +# /list_nodes endpoint +# ========================================================================= + +class TestListNodes: + def test_list_nodes_no_filter(self, api_client): + response = api_client.post("/list_nodes", json={}) + data = response.get_json() + assert response.status_code == 200 + assert isinstance(data, list) + assert "router1" in data + + def test_list_nodes_with_filter(self, api_client): + response = api_client.post("/list_nodes", json={"filter": "router.*"}) + data = response.get_json() + assert "router1" in data + assert all("router" in n or "Router" in n for n in data) + + def test_list_nodes_case_insensitive(self, api_client): + """Filter is lowercased when case=false.""" + response = api_client.post("/list_nodes", json={"filter": "ROUTER.*"}) + data = response.get_json() + # Should still match since the filter gets lowercased + assert isinstance(data, list) + + def test_list_nodes_no_body(self, api_client): + """No body returns all nodes.""" + response = api_client.post("/list_nodes", + data="", + content_type="application/json") + data = response.get_json() + assert isinstance(data, list) + + +# ========================================================================= +# /get_nodes endpoint +# ========================================================================= + +class TestGetNodes: + def test_get_nodes_no_filter(self, api_client): + response = api_client.post("/get_nodes", json={}) + data = response.get_json() + assert response.status_code == 200 + assert isinstance(data, dict) + assert "router1" in data + + def test_get_nodes_with_filter(self, api_client): + response = api_client.post("/get_nodes", json={"filter": "router.*"}) + data = response.get_json() + assert "router1" in data + assert "host" in data["router1"] + + def test_get_nodes_has_attributes(self, api_client): + response = api_client.post("/get_nodes", json={"filter": "router1"}) + data = response.get_json() + if "router1" in data: + assert "host" in data["router1"] + assert "protocol" in data["router1"] + + +# ========================================================================= +# /run_commands endpoint +# ========================================================================= + +class TestRunCommands: + def test_missing_action(self, api_client): + response = api_client.post("/run_commands", json={ + "nodes": "router1", + "commands": ["show version"] + }) + data = response.get_json() + assert "DataError" in data + assert "action" in data["DataError"] + + def test_missing_nodes(self, api_client): + response = api_client.post("/run_commands", json={ + "action": "run", + "commands": ["show version"] + }) + data = response.get_json() + assert "DataError" in data + assert "nodes" in data["DataError"] + + def test_missing_commands(self, api_client): + response = api_client.post("/run_commands", json={ + "action": "run", + "nodes": "router1" + }) + data = response.get_json() + assert "DataError" in data + assert "commands" in data["DataError"] + + def test_wrong_action(self, api_client): + response = api_client.post("/run_commands", json={ + "action": "invalid", + "nodes": "router1", + "commands": ["show version"] + }) + data = response.get_json() + assert "DataError" in data + assert "Wrong action" in data["DataError"] + + @patch("connpy.api.nodes") + def test_run_action(self, mock_nodes_cls, api_client): + """action=run executes and returns output.""" + mock_instance = MagicMock() + mock_instance.run.return_value = {"router1": "Router v1.0"} + mock_nodes_cls.return_value = mock_instance + + response = api_client.post("/run_commands", json={ + "action": "run", + "nodes": "router1", + "commands": ["show version"] + }) + data = response.get_json() + assert "router1" in data + + @patch("connpy.api.nodes") + def test_test_action(self, mock_nodes_cls, api_client): + """action=test returns result + output.""" + mock_instance = MagicMock() + mock_instance.test.return_value = {"router1": {"expected": True}} + mock_instance.output = {"router1": "output text"} + mock_nodes_cls.return_value = mock_instance + + response = api_client.post("/run_commands", json={ + "action": "test", + "nodes": "router1", + "commands": ["show version"], + "expected": "Router" + }) + data = response.get_json() + assert "result" in data + assert "output" in data + + @patch("connpy.api.nodes") + def test_run_with_options(self, mock_nodes_cls, api_client): + """Options get passed through.""" + mock_instance = MagicMock() + mock_instance.run.return_value = {"router1": "ok"} + mock_nodes_cls.return_value = mock_instance + + response = api_client.post("/run_commands", json={ + "action": "run", + "nodes": "router1", + "commands": ["show version"], + "options": {"timeout": 30, "parallel": 5} + }) + assert response.status_code == 200 + + @patch("connpy.api.nodes") + def test_run_folder_nodes(self, mock_nodes_cls, api_client): + """Nodes with @ prefix are resolved as folders.""" + mock_instance = MagicMock() + mock_instance.run.return_value = {"server1@office": "ok"} + mock_nodes_cls.return_value = mock_instance + + response = api_client.post("/run_commands", json={ + "action": "run", + "nodes": "@office", + "commands": ["ls -la"] + }) + assert response.status_code == 200 + + @patch("connpy.api.nodes") + def test_run_list_nodes(self, mock_nodes_cls, api_client): + """List of nodes is resolved correctly.""" + mock_instance = MagicMock() + mock_instance.run.return_value = {"router1": "ok", "server1@office": "ok"} + mock_nodes_cls.return_value = mock_instance + + response = api_client.post("/run_commands", json={ + "action": "run", + "nodes": ["router1", "server1@office"], + "commands": ["show version"] + }) + assert response.status_code == 200 + + +# ========================================================================= +# /ask_ai endpoint +# ========================================================================= + +class TestAskAI: + @patch("connpy.api.myai") + def test_ask_ai(self, mock_ai_cls, api_client): + mock_instance = MagicMock() + mock_instance.ask.return_value = {"response": "AI says hello"} + mock_ai_cls.return_value = mock_instance + + response = api_client.post("/ask_ai", json={ + "input": "list my routers" + }) + data = response.get_json() + assert data is not None + + @patch("connpy.api.myai") + def test_ask_ai_with_dryrun(self, mock_ai_cls, api_client): + mock_instance = MagicMock() + mock_instance.ask.return_value = {"response": "dry run"} + mock_ai_cls.return_value = mock_instance + + response = api_client.post("/ask_ai", json={ + "input": "test", + "dryrun": True + }) + assert response.status_code == 200 + + @patch("connpy.api.myai") + def test_ask_ai_with_history(self, mock_ai_cls, api_client): + mock_instance = MagicMock() + mock_instance.ask.return_value = {"response": "with history"} + mock_ai_cls.return_value = mock_instance + + response = api_client.post("/ask_ai", json={ + "input": "follow up", + "chat_history": [ + {"role": "user", "content": "previous"}, + {"role": "assistant", "content": "answer"} + ] + }) + assert response.status_code == 200 + + +# ========================================================================= +# /confirm endpoint +# ========================================================================= + +class TestConfirm: + @patch("connpy.api.myai") + def test_confirm(self, mock_ai_cls, api_client): + mock_instance = MagicMock() + mock_instance.confirm.return_value = True + mock_ai_cls.return_value = mock_instance + + response = api_client.post("/confirm", json={ + "input": "yes" + }) + assert response.status_code == 200 diff --git a/connpy/tests/test_completion.py b/connpy/tests/test_completion.py new file mode 100644 index 0000000..7d3ed29 --- /dev/null +++ b/connpy/tests/test_completion.py @@ -0,0 +1,182 @@ +"""Tests for connpy.completion module.""" +import os +import json +import pytest +from connpy.completion import _getallnodes, _getallfolders, _getcwd, _get_plugins + + +# ========================================================================= +# _getallnodes tests +# ========================================================================= + +class TestGetAllNodes: + def test_flat_nodes(self): + """Nodes without folders.""" + config = { + "connections": { + "router1": {"type": "connection"}, + "router2": {"type": "connection"} + } + } + nodes = _getallnodes(config) + assert "router1" in nodes + assert "router2" in nodes + + def test_nested_nodes(self): + """Nodes in folders and subfolders have correct format.""" + config = { + "connections": { + "router1": {"type": "connection"}, + "office": { + "type": "folder", + "server1": {"type": "connection"}, + "datacenter": { + "type": "subfolder", + "db1": {"type": "connection"} + } + } + } + } + nodes = _getallnodes(config) + assert "router1" in nodes + assert "server1@office" in nodes + assert "db1@datacenter@office" in nodes + + def test_empty_connections(self): + config = {"connections": {}} + nodes = _getallnodes(config) + assert nodes == [] + + +# ========================================================================= +# _getallfolders tests +# ========================================================================= + +class TestGetAllFolders: + def test_basic_folders(self): + config = { + "connections": { + "office": {"type": "folder"}, + "home": {"type": "folder"} + } + } + folders = _getallfolders(config) + assert "@office" in folders + assert "@home" in folders + + def test_with_subfolders(self): + config = { + "connections": { + "office": { + "type": "folder", + "datacenter": {"type": "subfolder"}, + "server1": {"type": "connection"} + } + } + } + folders = _getallfolders(config) + assert "@office" in folders + assert "@datacenter@office" in folders + + def test_empty(self): + config = {"connections": {}} + folders = _getallfolders(config) + assert folders == [] + + +# ========================================================================= +# _getcwd tests +# ========================================================================= + +class TestGetCwd: + def test_current_dir(self, tmp_path, monkeypatch): + """Lists files in current directory.""" + monkeypatch.chdir(tmp_path) + (tmp_path / "file1.txt").touch() + (tmp_path / "file2.py").touch() + subdir = tmp_path / "subdir" + subdir.mkdir() + + result = _getcwd(["run", "run"], "run") + # Should list files + assert any("file1.txt" in r for r in result) + assert any("subdir/" in r for r in result) + + def test_specific_path(self, tmp_path, monkeypatch): + """Lists files matching a partial path.""" + monkeypatch.chdir(tmp_path) + (tmp_path / "script.yaml").touch() + (tmp_path / "script2.yaml").touch() + + result = _getcwd(["run", "script"], "run") + assert any("script" in r for r in result) + + def test_folder_only(self, tmp_path, monkeypatch): + """folderonly=True returns only directories.""" + monkeypatch.chdir(tmp_path) + (tmp_path / "file.txt").touch() + subdir = tmp_path / "mydir" + subdir.mkdir() + + result = _getcwd(["export", "export"], "export", folderonly=True) + files_in_result = [r for r in result if "file.txt" in r] + assert len(files_in_result) == 0 + dirs_in_result = [r for r in result if "mydir" in r] + assert len(dirs_in_result) > 0 + + +# ========================================================================= +# _get_plugins tests +# ========================================================================= + +class TestGetPlugins: + def test_get_plugins_disable(self, tmp_path): + """--disable returns enabled plugins.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + (plugin_dir / "active.py").touch() + (plugin_dir / "disabled.py.bkp").touch() + + result = _get_plugins("--disable", str(tmp_path)) + assert "active" in result + assert "disabled" not in result + + def test_get_plugins_enable(self, tmp_path): + """--enable returns disabled plugins.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + (plugin_dir / "active.py").touch() + (plugin_dir / "disabled.py.bkp").touch() + + result = _get_plugins("--enable", str(tmp_path)) + assert "disabled" in result + assert "active" not in result + + def test_get_plugins_del(self, tmp_path): + """--del returns all plugins.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + (plugin_dir / "active.py").touch() + (plugin_dir / "disabled.py.bkp").touch() + + result = _get_plugins("--del", str(tmp_path)) + assert "active" in result + assert "disabled" in result + + def test_get_plugins_all(self, tmp_path): + """'all' returns dict with paths.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + (plugin_dir / "myplugin.py").touch() + + result = _get_plugins("all", str(tmp_path)) + assert isinstance(result, dict) + assert "myplugin" in result + + def test_get_plugins_empty_dir(self, tmp_path): + """Empty plugins directory returns empty list.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + result = _get_plugins("--disable", str(tmp_path)) + assert result == [] diff --git a/connpy/tests/test_configfile.py b/connpy/tests/test_configfile.py new file mode 100644 index 0000000..6420bd7 --- /dev/null +++ b/connpy/tests/test_configfile.py @@ -0,0 +1,376 @@ +"""Tests for connpy.configfile module.""" +import json +import os +import re +import pytest +from copy import deepcopy + + +class TestConfigfileInit: + def test_creates_default_config(self, tmp_config_dir): + """Creates config.json with defaults when it doesn't exist.""" + config_file = tmp_config_dir / "config.json" + config_file.unlink() # Remove existing + key_file = tmp_config_dir / ".osk" + + from connpy.configfile import configfile + conf = configfile(conf=str(config_file), key=str(key_file)) + + assert config_file.exists() + assert conf.config["case"] == False + assert conf.config["idletime"] == 30 + assert "default" in conf.profiles + + def test_creates_rsa_key(self, tmp_config_dir): + """Generates RSA key when it doesn't exist.""" + key_file = tmp_config_dir / ".osk" + key_file.unlink() # Remove existing + + from connpy.configfile import configfile + conf = configfile(conf=str(tmp_config_dir / "config.json"), key=str(key_file)) + + assert key_file.exists() + assert conf.privatekey is not None + assert conf.publickey is not None + + def test_loads_existing_config(self, config): + """Loads correctly from existing config.""" + assert config.config is not None + assert config.connections is not None + assert config.profiles is not None + + def test_config_file_permissions(self, tmp_config_dir): + """Config is created with 0o600 permissions.""" + config_file = tmp_config_dir / "config.json" + config_file.unlink() + + from connpy.configfile import configfile + configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk")) + + stat = os.stat(str(config_file)) + assert oct(stat.st_mode & 0o777) == oct(0o600) + + def test_custom_paths(self, tmp_path): + """Accepts custom paths for conf and key.""" + config_dir = tmp_path / "custom" + config_dir.mkdir() + (config_dir / "plugins").mkdir() + + # Write .folder for the config dir + dot_folder = tmp_path / ".config" / "conn" + dot_folder.mkdir(parents=True, exist_ok=True) + (dot_folder / ".folder").write_text(str(config_dir)) + (dot_folder / "plugins").mkdir(exist_ok=True) + + conf_path = str(config_dir / "my_config.json") + key_path = str(config_dir / "my_key") + + from connpy.configfile import configfile + conf = configfile(conf=conf_path, key=key_path) + + assert conf.file == conf_path + assert conf.key == key_path + + +class TestEncryption: + def test_encrypt_password(self, config): + """Encrypts and produces b'...' format.""" + encrypted = config.encrypt("mysecret") + assert encrypted.startswith("b'") or encrypted.startswith('b"') + + def test_encrypt_decrypt_roundtrip(self, config): + """Encrypt then decrypt returns original.""" + from Crypto.PublicKey import RSA + from Crypto.Cipher import PKCS1_OAEP + import ast + + original = "super_secret_password" + encrypted = config.encrypt(original) + + # Decrypt + with open(config.key) as f: + key = RSA.import_key(f.read()) + decryptor = PKCS1_OAEP.new(key) + decrypted = decryptor.decrypt(ast.literal_eval(encrypted)).decode("utf-8") + assert decrypted == original + + +class TestExplodeUnique: + def test_simple_node(self, config): + result = config._explode_unique("router1") + assert result == {"id": "router1"} + + def test_node_with_folder(self, config): + result = config._explode_unique("r1@office") + assert result == {"id": "r1", "folder": "office"} + + def test_node_with_subfolder(self, config): + result = config._explode_unique("r1@dc@office") + assert result == {"id": "r1", "folder": "office", "subfolder": "dc"} + + def test_folder_only(self, config): + result = config._explode_unique("@office") + assert result == {"folder": "office"} + + def test_subfolder_only(self, config): + result = config._explode_unique("@dc@office") + assert result == {"folder": "office", "subfolder": "dc"} + + def test_too_deep(self, config): + result = config._explode_unique("a@b@c@d") + assert result == False + + def test_empty_folder(self, config): + result = config._explode_unique("a@") + assert result == False + + def test_empty_subfolder(self, config): + result = config._explode_unique("a@@office") + assert result == False + + +class TestCRUDNodes: + def test_add_node_root(self, config): + config._connections_add( + id="router1", host="10.0.0.1", protocol="ssh", + port="22", user="admin", password="pass", options="", + logs="", tags="", jumphost="" + ) + assert "router1" in config.connections + assert config.connections["router1"]["host"] == "10.0.0.1" + + def test_add_node_folder(self, config): + config._folder_add(folder="office") + config._connections_add( + id="server1", folder="office", host="10.0.1.1", + protocol="ssh", port="", user="root", password="pass", + options="", logs="", tags="", jumphost="" + ) + assert "server1" in config.connections["office"] + + def test_add_node_subfolder(self, config): + config._folder_add(folder="office") + config._folder_add(folder="office", subfolder="dc") + config._connections_add( + id="db1", folder="office", subfolder="dc", host="10.0.2.1", + protocol="ssh", port="", user="dbadmin", password="pass", + options="", logs="", tags="", jumphost="" + ) + assert "db1" in config.connections["office"]["dc"] + + def test_del_node_root(self, config): + config._connections_add( + id="router1", host="10.0.0.1", protocol="ssh", + port="", user="", password="", options="", + logs="", tags="", jumphost="" + ) + config._connections_del(id="router1") + assert "router1" not in config.connections + + def test_del_node_folder(self, config): + config._folder_add(folder="office") + config._connections_add( + id="server1", folder="office", host="10.0.1.1", + protocol="ssh", port="", user="", password="", + options="", logs="", tags="", jumphost="" + ) + config._connections_del(id="server1", folder="office") + assert "server1" not in config.connections["office"] + + def test_add_folder(self, config): + config._folder_add(folder="office") + assert "office" in config.connections + assert config.connections["office"]["type"] == "folder" + + def test_add_subfolder(self, config): + config._folder_add(folder="office") + config._folder_add(folder="office", subfolder="dc") + assert "dc" in config.connections["office"] + assert config.connections["office"]["dc"]["type"] == "subfolder" + + def test_del_folder(self, config): + config._folder_add(folder="office") + config._folder_del(folder="office") + assert "office" not in config.connections + + def test_del_subfolder(self, config): + config._folder_add(folder="office") + config._folder_add(folder="office", subfolder="dc") + config._folder_del(folder="office", subfolder="dc") + assert "dc" not in config.connections["office"] + + +class TestCRUDProfiles: + def test_add_profile(self, config): + config._profiles_add( + id="myprofile", host="", protocol="telnet", + port="23", user="user1", password="pass1", + options="", logs="", tags="", jumphost="" + ) + assert "myprofile" in config.profiles + assert config.profiles["myprofile"]["protocol"] == "telnet" + + def test_del_profile(self, config): + config._profiles_add( + id="temp", host="", protocol="ssh", port="", + user="", password="", options="", logs="", tags="", jumphost="" + ) + config._profiles_del(id="temp") + assert "temp" not in config.profiles + + def test_default_profile_exists(self, config): + assert "default" in config.profiles + + +class TestGetItem: + def test_getitem_node(self, populated_config): + node = populated_config.getitem("router1") + assert node["host"] == "10.0.0.1" + assert "type" not in node # type is stripped + + def test_getitem_folder(self, populated_config): + nodes = populated_config.getitem("@office") + # Should contain server1@office but NOT datacenter (subfolder) + assert "server1@office" in nodes + assert all("type" not in v for v in nodes.values()) + + def test_getitem_subfolder(self, populated_config): + nodes = populated_config.getitem("@datacenter@office") + assert "db1@datacenter@office" in nodes + + def test_getitem_node_in_folder(self, populated_config): + node = populated_config.getitem("server1@office") + assert node["host"] == "10.0.1.1" + + def test_getitem_node_in_subfolder(self, populated_config): + node = populated_config.getitem("db1@datacenter@office") + assert node["host"] == "10.0.2.1" + + def test_getitem_with_profile_extraction(self, tmp_config_dir): + """extract=True resolves @profile references.""" + config_file = tmp_config_dir / "config.json" + data = { + "config": {"case": False, "idletime": 30, "fzf": False}, + "connections": { + "router1": { + "host": "10.0.0.1", "protocol": "ssh", "port": "", + "user": "@office-user", "password": "@office-user", + "options": "", "logs": "", "tags": "", "jumphost": "", + "type": "connection" + } + }, + "profiles": { + "default": {"host": "", "protocol": "ssh", "port": "", + "user": "", "password": "", "options": "", + "logs": "", "tags": "", "jumphost": ""}, + "office-user": {"host": "", "protocol": "ssh", "port": "", + "user": "officeadmin", "password": "officepass", + "options": "", "logs": "", "tags": "", "jumphost": ""} + } + } + config_file.write_text(json.dumps(data, indent=4)) + + from connpy.configfile import configfile + conf = configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk")) + + node = conf.getitem("router1", extract=True) + assert node["user"] == "officeadmin" + assert node["password"] == "officepass" + + def test_getitems_multiple(self, populated_config): + nodes = populated_config.getitems(["router1", "server1@office"]) + assert "router1" in nodes + assert "server1@office" in nodes + + def test_getitems_folder(self, populated_config): + nodes = populated_config.getitems(["@office"]) + assert "server1@office" in nodes + + +class TestGetAll: + def test_getallnodes_no_filter(self, populated_config): + nodes = populated_config._getallnodes() + assert "router1" in nodes + assert "server1@office" in nodes + assert "db1@datacenter@office" in nodes + + def test_getallnodes_string_filter(self, populated_config): + nodes = populated_config._getallnodes("router.*") + assert "router1" in nodes + assert "server1@office" not in nodes + + def test_getallnodes_list_filter(self, populated_config): + nodes = populated_config._getallnodes(["router.*", "db.*"]) + assert "router1" in nodes + assert "db1@datacenter@office" in nodes + assert "server1@office" not in nodes + + def test_getallnodes_filter_invalid_type(self, populated_config): + with pytest.raises(ValueError): + populated_config._getallnodes(123) + + def test_getallfolders(self, populated_config): + folders = populated_config._getallfolders() + assert "@office" in folders + assert "@datacenter@office" in folders + + def test_getallnodesfull(self, populated_config): + nodes = populated_config._getallnodesfull() + assert "router1" in nodes + assert nodes["router1"]["host"] == "10.0.0.1" + + def test_getallnodesfull_with_filter(self, populated_config): + nodes = populated_config._getallnodesfull("router.*") + assert "router1" in nodes + assert "server1@office" not in nodes + + def test_profileused(self, tmp_config_dir): + """Detects nodes using a specific profile.""" + config_file = tmp_config_dir / "config.json" + data = { + "config": {"case": False, "idletime": 30, "fzf": False}, + "connections": { + "router1": { + "host": "10.0.0.1", "protocol": "ssh", "port": "", + "user": "@myprofile", "password": "pass", + "options": "", "logs": "", "tags": "", "jumphost": "", + "type": "connection" + }, + "router2": { + "host": "10.0.0.2", "protocol": "ssh", "port": "", + "user": "admin", "password": "pass", + "options": "", "logs": "", "tags": "", "jumphost": "", + "type": "connection" + } + }, + "profiles": { + "default": {"host": "", "protocol": "ssh", "port": "", + "user": "", "password": "", "options": "", + "logs": "", "tags": "", "jumphost": ""}, + "myprofile": {"host": "", "protocol": "ssh", "port": "", + "user": "profuser", "password": "profpass", + "options": "", "logs": "", "tags": "", "jumphost": ""} + } + } + config_file.write_text(json.dumps(data, indent=4)) + from connpy.configfile import configfile + conf = configfile(conf=str(config_file), key=str(tmp_config_dir / ".osk")) + + used = conf._profileused("myprofile") + assert "router1" in used + assert "router2" not in used + + def test_saveconfig(self, config): + """Save and reload correctly.""" + config._connections_add( + id="test_node", host="1.2.3.4", protocol="ssh", + port="", user="", password="", options="", + logs="", tags="", jumphost="" + ) + result = config._saveconfig(config.file) + assert result == 0 + + # Reload and verify + from connpy.configfile import configfile + reloaded = configfile(conf=config.file, key=config.key) + assert "test_node" in reloaded.connections diff --git a/connpy/tests/test_core.py b/connpy/tests/test_core.py new file mode 100644 index 0000000..54a85ef --- /dev/null +++ b/connpy/tests/test_core.py @@ -0,0 +1,419 @@ +"""Tests for connpy.core module — node and nodes classes.""" +import json +import os +import io +import re +import pytest +from unittest.mock import patch, MagicMock, PropertyMock +from copy import deepcopy + + +# ========================================================================= +# node.__init__ tests +# ========================================================================= + +class TestNodeInit: + def test_basic_init(self): + """Creates node with basic attributes.""" + from connpy.core import node + n = node("router1", "10.0.0.1", user="admin", password="pass1", protocol="ssh") + assert n.unique == "router1" + assert n.host == "10.0.0.1" + assert n.user == "admin" + assert n.protocol == "ssh" + assert n.password == ["pass1"] + + def test_default_protocol(self): + """Default protocol is ssh.""" + from connpy.core import node + n = node("router1", "10.0.0.1") + assert n.protocol == "ssh" + + def test_password_as_list_of_profiles(self, populated_config): + """Password list with @profile references resolves correctly.""" + from connpy.core import node + n = node("router1", "10.0.0.1", password=["@office-user"], + config=populated_config) + assert n.password == ["officepass"] + + def test_password_plain_string(self): + """Plain string password is wrapped in a list.""" + from connpy.core import node + n = node("router1", "10.0.0.1", password="mypass") + assert n.password == ["mypass"] + + def test_node_with_profile(self, populated_config): + """Resolves @profile references for user.""" + from connpy.core import node + n = node("test1", "10.0.0.1", user="@office-user", password="plain", + config=populated_config) + assert n.user == "officeadmin" + + def test_node_tags(self): + """Tags are stored correctly.""" + from connpy.core import node + tags = {"os": "cisco_ios", "prompt": r"Router#"} + n = node("router1", "10.0.0.1", tags=tags) + assert n.tags["os"] == "cisco_ios" + + +# ========================================================================= +# Command generation tests +# ========================================================================= + +class TestCommandGeneration: + def _make_node(self, **kwargs): + from connpy.core import node + defaults = { + "unique": "test", "host": "10.0.0.1", "protocol": "ssh", + "user": "admin", "password": "", "port": "", "options": "", + "jumphost": "", "tags": "", "logs": "" + } + defaults.update(kwargs) + return node(defaults.pop("unique"), defaults.pop("host"), **defaults) + + def test_ssh_cmd_basic(self): + n = self._make_node() + cmd = n._get_cmd() + assert "ssh" in cmd + assert "admin@10.0.0.1" in cmd + + def test_ssh_cmd_port(self): + n = self._make_node(port="2222") + cmd = n._get_cmd() + assert "-p 2222" in cmd + + def test_ssh_cmd_options(self): + n = self._make_node(options="-o StrictHostKeyChecking=no") + cmd = n._get_cmd() + assert "-o StrictHostKeyChecking=no" in cmd + + def test_sftp_cmd_port(self): + n = self._make_node(protocol="sftp", port="2222") + cmd = n._get_cmd() + assert "-P 2222" in cmd # SFTP uses uppercase P + + def test_telnet_cmd(self): + n = self._make_node(protocol="telnet", port="23") + cmd = n._get_cmd() + assert "telnet 10.0.0.1" in cmd + assert "23" in cmd + + def test_kubectl_cmd(self): + n = self._make_node(protocol="kubectl", host="my-pod", tags={"kube_command": "/bin/sh"}) + cmd = n._get_cmd() + assert "kubectl exec" in cmd + assert "my-pod" in cmd + assert "/bin/sh" in cmd + + def test_kubectl_cmd_default_command(self): + n = self._make_node(protocol="kubectl", host="my-pod") + cmd = n._get_cmd() + assert "/bin/bash" in cmd + + def test_docker_cmd(self): + n = self._make_node(protocol="docker", host="my-container", + tags={"docker_command": "/bin/sh"}) + cmd = n._get_cmd() + assert "docker" in cmd + assert "my-container" in cmd + assert "/bin/sh" in cmd + + def test_invalid_protocol_raises(self): + n = self._make_node(protocol="invalid_proto") + with pytest.raises(ValueError, match="Invalid protocol"): + n._get_cmd() + + def test_ssh_cmd_no_user(self): + n = self._make_node(user="") + cmd = n._get_cmd() + assert "10.0.0.1" in cmd + assert "@" not in cmd # No user@ prefix + + +# ========================================================================= +# Password decryption tests +# ========================================================================= + +class TestPasswordDecryption: + def test_passtx_plaintext(self, config): + """Plaintext passwords pass through unchanged.""" + from connpy.core import node + n = node("test", "10.0.0.1", password="plainpass", config=config) + result = n._passtx(["plainpass"]) + assert result == ["plainpass"] + + def test_passtx_encrypted(self, config): + """Encrypted passwords get decrypted.""" + from connpy.core import node + encrypted = config.encrypt("mysecret") + n = node("test", "10.0.0.1", password=encrypted, config=config) + result = n._passtx([encrypted]) + assert result == ["mysecret"] + + def test_passtx_missing_key_raises(self): + """Missing key file raises ValueError.""" + from connpy.core import node + n = node("test", "10.0.0.1", password="pass") + # A password formatted as encrypted but no valid key + with pytest.raises((ValueError, Exception)): + n._passtx(["""b'corrupted_encrypted_data'"""], keyfile="/nonexistent") + + +# ========================================================================= +# Log handling tests +# ========================================================================= + +class TestLogHandling: + def test_logfile_variable_substitution(self): + from connpy.core import node + n = node("router1", "10.0.0.1", user="admin", protocol="ssh", port="22", + logs="/logs/${unique}_${host}_${user}") + result = n._logfile() + assert result == "/logs/router1_10.0.0.1_admin" + + def test_logfile_date_substitution(self): + from connpy.core import node + import datetime + n = node("router1", "10.0.0.1", logs="/logs/${date '%Y'}") + result = n._logfile() + assert datetime.datetime.now().strftime("%Y") in result + + def test_logclean_removes_ansi(self): + from connpy.core import node + n = node("test", "10.0.0.1") + dirty = "\x1B[32mgreen text\x1B[0m" + clean = n._logclean(dirty, var=True) + assert "\x1B" not in clean + assert "green text" in clean + + def test_logclean_removes_backspaces(self): + from connpy.core import node + n = node("test", "10.0.0.1") + dirty = "type\bo" + clean = n._logclean(dirty, var=True) + assert "\b" not in clean + + +# ========================================================================= +# run() and test() with mock pexpect +# ========================================================================= + +class TestNodeRun: + def _make_connected_node(self, mock_pexpect_obj, **kwargs): + """Create a node and mock its _connect to succeed.""" + from connpy.core import node + defaults = { + "unique": "router1", "host": "10.0.0.1", + "protocol": "ssh", "user": "admin", "password": "" + } + defaults.update(kwargs) + n = node(defaults.pop("unique"), defaults.pop("host"), **defaults) + return n + + def test_run_returns_output(self, mock_pexpect): + """run() returns string output.""" + child = mock_pexpect["child"] + pexp = mock_pexpect["pexpect"] + + # Simulate: connect succeeds, command runs, prompt found + child.expect.return_value = 9 # prompt index for ssh + child.logfile_read = None + + from connpy.core import node + n = node("router1", "10.0.0.1", user="admin", password="") + + # Mock _connect to return True and set up child + with patch.object(n, '_connect', return_value=True): + n.child = child + log_buffer = io.BytesIO(b"show version\nRouter v1.0\nrouter#") + n.mylog = log_buffer + child.logfile_read = log_buffer + + with patch.object(n, '_logclean', return_value="Router v1.0"): + output = n.run(["show version"]) + + assert n.status == 0 + assert output == "Router v1.0" + + def test_run_status_1_on_failure(self, mock_pexpect): + """Status 1 when connection fails.""" + from connpy.core import node + n = node("router1", "10.0.0.1", user="admin", password="") + + with patch.object(n, '_connect', return_value="Connection failed code: 1\nrefused"): + output = n.run(["show version"]) + + assert n.status == 1 + assert "refused" in output + + def test_run_with_variables(self, mock_pexpect): + """Variables get substituted in commands.""" + child = mock_pexpect["child"] + child.expect.return_value = 9 + + from connpy.core import node + n = node("router1", "10.0.0.1", user="admin", password="") + + sent_commands = [] + child.sendline.side_effect = lambda cmd: sent_commands.append(cmd) + + with patch.object(n, '_connect', return_value=True): + n.child = child + n.mylog = io.BytesIO(b"output") + with patch.object(n, '_logclean', return_value="output"): + n.run(["show ip route {subnet}"], vars={"subnet": "10.0.0.0/24"}) + + assert "show ip route 10.0.0.0/24" in sent_commands + + def test_run_saves_to_folder(self, mock_pexpect, tmp_path): + """folder param saves log file.""" + child = mock_pexpect["child"] + child.expect.return_value = 9 + + from connpy.core import node + n = node("router1", "10.0.0.1", user="admin", password="") + + with patch.object(n, '_connect', return_value=True): + n.child = child + n.mylog = io.BytesIO(b"log output") + with patch.object(n, '_logclean', return_value="log output"): + n.run(["show version"], folder=str(tmp_path)) + + log_files = list(tmp_path.glob("router1_*.txt")) + assert len(log_files) == 1 + assert "log output" in log_files[0].read_text() + + +class TestNodeTest: + def test_test_returns_dict(self, mock_pexpect): + """test() returns dict of results.""" + child = mock_pexpect["child"] + child.expect.return_value = 0 # prompt found (index 0 in test expects) + + from connpy.core import node + n = node("router1", "10.0.0.1", user="admin", password="") + + with patch.object(n, '_connect', return_value=True): + n.child = child + n.mylog = io.BytesIO(b"1.1.1.1 is up") + with patch.object(n, '_logclean', return_value="1.1.1.1 is up"): + result = n.test(["ping 1.1.1.1"], "1.1.1.1") + + assert isinstance(result, dict) + assert result.get("1.1.1.1") == True + + def test_test_expected_not_found(self, mock_pexpect): + """Expected text not found returns False.""" + child = mock_pexpect["child"] + child.expect.return_value = 0 + + from connpy.core import node + n = node("router1", "10.0.0.1", user="admin", password="") + + with patch.object(n, '_connect', return_value=True): + n.child = child + n.mylog = io.BytesIO(b"some other output") + with patch.object(n, '_logclean', return_value="some other output"): + result = n.test(["ping 1.1.1.1"], "1.1.1.1") + + assert isinstance(result, dict) + assert result.get("1.1.1.1") == False + + +# ========================================================================= +# nodes (parallel) tests +# ========================================================================= + +class TestNodes: + def test_nodes_init(self): + """Creates list of node objects.""" + from connpy.core import nodes + nodes_dict = { + "r1": {"host": "10.0.0.1", "user": "admin", "password": ""}, + "r2": {"host": "10.0.0.2", "user": "admin", "password": ""} + } + mynodes = nodes(nodes_dict) + assert len(mynodes.nodelist) == 2 + assert hasattr(mynodes, "r1") + assert hasattr(mynodes, "r2") + + def test_nodes_run_parallel(self): + """run() executes on all nodes and returns dict.""" + from connpy.core import nodes + + nodes_dict = { + "r1": {"host": "10.0.0.1", "user": "admin", "password": ""}, + "r2": {"host": "10.0.0.2", "user": "admin", "password": ""} + } + mynodes = nodes(nodes_dict) + + # Mock run on each node — must set output AND status on the node + for n in mynodes.nodelist: + original_node = n # capture by value + def make_mock(node_ref): + def mock_run(commands, **kwargs): + node_ref.output = f"output from {node_ref.unique}" + node_ref.status = 0 + return mock_run + n.run = make_mock(n) + + result = mynodes.run(["show version"]) + assert "r1" in result + assert "r2" in result + + def test_nodes_splitlist(self): + """_splitlist divides list correctly.""" + from connpy.core import nodes + mynodes = nodes({"r1": {"host": "1.1.1.1", "user": "", "password": ""}}) + chunks = list(mynodes._splitlist([1, 2, 3, 4, 5], 2)) + assert chunks == [[1, 2], [3, 4], [5]] + + def test_nodes_run_with_vars(self): + """Variables per node and __global__ work.""" + from connpy.core import nodes + + nodes_dict = { + "r1": {"host": "10.0.0.1", "user": "admin", "password": ""}, + } + mynodes = nodes(nodes_dict) + + captured_vars = {} + + def mock_run(commands, vars=None, **kwargs): + captured_vars.update(vars or {}) + mynodes.r1.output = "ok" + mynodes.r1.status = 0 + + mynodes.r1.run = mock_run + + variables = { + "__global__": {"mask": "255.255.255.0"}, + "r1": {"ip": "10.0.0.1"} + } + mynodes.run(["show ip"], vars=variables) + assert captured_vars.get("mask") == "255.255.255.0" + assert captured_vars.get("ip") == "10.0.0.1" + + def test_nodes_on_complete_callback(self): + """on_complete callback fires per node.""" + from connpy.core import nodes + + nodes_dict = { + "r1": {"host": "10.0.0.1", "user": "admin", "password": ""}, + } + mynodes = nodes(nodes_dict) + + completed = [] + + def mock_run(commands, **kwargs): + mynodes.r1.output = "done" + mynodes.r1.status = 0 + + mynodes.r1.run = mock_run + + def on_done(unique, output, status): + completed.append(unique) + + mynodes.run(["show version"], on_complete=on_done) + assert "r1" in completed diff --git a/connpy/tests/test_hooks.py b/connpy/tests/test_hooks.py new file mode 100644 index 0000000..26f3350 --- /dev/null +++ b/connpy/tests/test_hooks.py @@ -0,0 +1,216 @@ +"""Tests for connpy.hooks module — MethodHook and ClassHook.""" +import pytest +from connpy.hooks import MethodHook, ClassHook + + +# ========================================================================= +# MethodHook Tests +# ========================================================================= + +class TestMethodHook: + def test_basic_call(self): + """Decorated function executes normally.""" + @MethodHook + def add(a, b): + return a + b + assert add(2, 3) == 5 + + def test_pre_hook_modifies_args(self): + """Pre-hook can modify arguments before execution.""" + @MethodHook + def greet(name): + return f"Hello {name}" + + def uppercase_hook(name): + return (name.upper(),), {} + + greet.register_pre_hook(uppercase_hook) + assert greet("world") == "Hello WORLD" + + def test_post_hook_modifies_result(self): + """Post-hook can modify the return value.""" + @MethodHook + def compute(x): + return x * 2 + + def double_result(*args, **kwargs): + return kwargs["result"] * 2 + + compute.register_post_hook(double_result) + assert compute(5) == 20 # 5*2=10, then 10*2=20 + + def test_multiple_pre_hooks_order(self): + """Pre-hooks execute in registration order.""" + calls = [] + + @MethodHook + def func(x): + return x + + def hook1(x): + calls.append("hook1") + return (x,), {} + + def hook2(x): + calls.append("hook2") + return (x,), {} + + func.register_pre_hook(hook1) + func.register_pre_hook(hook2) + func(1) + assert calls == ["hook1", "hook2"] + + def test_multiple_post_hooks_order(self): + """Post-hooks execute in registration order.""" + calls = [] + + @MethodHook + def func(x): + return x + + def hook1(*args, **kwargs): + calls.append("hook1") + return kwargs["result"] + + def hook2(*args, **kwargs): + calls.append("hook2") + return kwargs["result"] + + func.register_post_hook(hook1) + func.register_post_hook(hook2) + func(1) + assert calls == ["hook1", "hook2"] + + def test_pre_hook_exception_continues(self, capsys): + """If a pre-hook raises, the function still executes.""" + @MethodHook + def func(x): + return x + 1 + + def bad_hook(x): + raise RuntimeError("broken hook") + + func.register_pre_hook(bad_hook) + # Should not raise — the hook error is printed but execution continues + result = func(5) + assert result == 6 + + def test_post_hook_exception_continues(self, capsys): + """If a post-hook raises, the result is still returned.""" + @MethodHook + def func(x): + return x + 1 + + def bad_hook(*args, **kwargs): + raise RuntimeError("broken post hook") + + func.register_post_hook(bad_hook) + result = func(5) + assert result == 6 + + def test_method_hook_as_instance_method(self): + """MethodHook works as a descriptor on a class.""" + class MyClass: + @MethodHook + def double(self, x): + return x * 2 + + obj = MyClass() + assert obj.double(5) == 10 + + def test_method_hook_instance_hook_registration(self): + """Can register hooks via instance method access.""" + class MyClass: + @MethodHook + def process(self, x): + return x + + def add_ten(*args, **kwargs): + return kwargs["result"] + 10 + + obj = MyClass() + obj.process.register_post_hook(add_ten) + assert obj.process(5) == 15 + + +# ========================================================================= +# ClassHook Tests +# ========================================================================= + +class TestClassHook: + def test_creates_instance(self): + """ClassHook still creates instances normally.""" + @ClassHook + class MyClass: + def __init__(self, value): + self.value = value + + obj = MyClass(42) + assert obj.value == 42 + + def test_modify_future_instances(self): + """modify() affects all future instances.""" + @ClassHook + class MyClass: + def __init__(self): + self.x = 1 + + def set_x_to_99(instance): + instance.x = 99 + + MyClass.modify(set_x_to_99) + obj = MyClass() + assert obj.x == 99 + + def test_modify_does_not_affect_past(self): + """modify() does not affect already-created instances.""" + @ClassHook + class MyClass: + def __init__(self): + self.x = 1 + + old_obj = MyClass() + + def set_x_to_99(instance): + instance.x = 99 + + MyClass.modify(set_x_to_99) + assert old_obj.x == 1 # Not affected + assert MyClass().x == 99 # New instance IS affected + + def test_instance_modify(self): + """instance.modify() only affects that specific instance.""" + @ClassHook + class MyClass: + def __init__(self): + self.x = 1 + + obj1 = MyClass() + obj2 = MyClass() + + obj1.modify(lambda inst: setattr(inst, 'x', 999)) + assert obj1.x == 999 + assert obj2.x == 1 + + def test_multiple_deferred_hooks(self): + """Multiple modify() calls apply in order.""" + @ClassHook + class MyClass: + def __init__(self): + self.log = [] + + MyClass.modify(lambda inst: inst.log.append("first")) + MyClass.modify(lambda inst: inst.log.append("second")) + + obj = MyClass() + assert obj.log == ["first", "second"] + + def test_getattr_delegation(self): + """ClassHook delegates attribute access to the wrapped class.""" + @ClassHook + class MyClass: + class_var = "hello" + def __init__(self): + pass + + assert MyClass.class_var == "hello" diff --git a/connpy/tests/test_plugins.py b/connpy/tests/test_plugins.py new file mode 100644 index 0000000..87a1721 --- /dev/null +++ b/connpy/tests/test_plugins.py @@ -0,0 +1,327 @@ +"""Tests for connpy.plugins module.""" +import os +import textwrap +import pytest +from connpy.plugins import Plugins + + +# --------------------------------------------------------------------------- +# Helper: write a plugin script to a file +# --------------------------------------------------------------------------- +def _write_plugin(path, code): + """Write dedented code to a file.""" + with open(path, "w") as f: + f.write(textwrap.dedent(code)) + + +# ========================================================================= +# verify_script tests +# ========================================================================= + +class TestVerifyScript: + def test_valid_parser_entrypoint(self, tmp_path): + p = tmp_path / "good.py" + _write_plugin(p, """\ + import argparse + + class Parser: + def __init__(self): + self.parser = argparse.ArgumentParser() + + class Entrypoint: + def __init__(self, args, parser, connapp): + pass + """) + plugins = Plugins() + assert plugins.verify_script(str(p)) == False + + def test_valid_preload_only(self, tmp_path): + p = tmp_path / "preload.py" + _write_plugin(p, """\ + class Preload: + def __init__(self, connapp): + pass + """) + plugins = Plugins() + assert plugins.verify_script(str(p)) == False + + def test_valid_all_three(self, tmp_path): + p = tmp_path / "all.py" + _write_plugin(p, """\ + import argparse + + class Parser: + def __init__(self): + self.parser = argparse.ArgumentParser() + + class Entrypoint: + def __init__(self, args, parser, connapp): + pass + + class Preload: + def __init__(self, connapp): + pass + """) + plugins = Plugins() + assert plugins.verify_script(str(p)) == False + + def test_parser_without_entrypoint(self, tmp_path): + p = tmp_path / "bad.py" + _write_plugin(p, """\ + import argparse + + class Parser: + def __init__(self): + self.parser = argparse.ArgumentParser() + """) + plugins = Plugins() + result = plugins.verify_script(str(p)) + assert result # Should be a truthy error string + assert "Entrypoint" in result + + def test_entrypoint_without_parser(self, tmp_path): + p = tmp_path / "bad.py" + _write_plugin(p, """\ + class Entrypoint: + def __init__(self, args, parser, connapp): + pass + """) + plugins = Plugins() + result = plugins.verify_script(str(p)) + assert result + assert "Parser" in result + + def test_no_valid_class(self, tmp_path): + p = tmp_path / "empty.py" + _write_plugin(p, """\ + def some_function(): + pass + """) + plugins = Plugins() + result = plugins.verify_script(str(p)) + assert result + assert "No valid class" in result + + def test_parser_missing_self_parser(self, tmp_path): + p = tmp_path / "bad.py" + _write_plugin(p, """\ + class Parser: + def __init__(self): + self.something = "not parser" + + class Entrypoint: + def __init__(self, args, parser, connapp): + pass + """) + plugins = Plugins() + result = plugins.verify_script(str(p)) + assert result + assert "self.parser" in result + + def test_entrypoint_wrong_args(self, tmp_path): + p = tmp_path / "bad.py" + _write_plugin(p, """\ + import argparse + + class Parser: + def __init__(self): + self.parser = argparse.ArgumentParser() + + class Entrypoint: + def __init__(self, args): + pass + """) + plugins = Plugins() + result = plugins.verify_script(str(p)) + assert result + assert "Entrypoint" in result + + def test_preload_wrong_args(self, tmp_path): + p = tmp_path / "bad.py" + _write_plugin(p, """\ + class Preload: + def __init__(self, connapp, extra): + pass + """) + plugins = Plugins() + result = plugins.verify_script(str(p)) + assert result + assert "Preload" in result + + def test_disallowed_top_level(self, tmp_path): + p = tmp_path / "bad.py" + _write_plugin(p, """\ + MY_GLOBAL = "not allowed" + + class Preload: + def __init__(self, connapp): + pass + """) + plugins = Plugins() + result = plugins.verify_script(str(p)) + assert result + assert "not allowed" in result.lower() or "Plugin can only have" in result + + def test_syntax_error(self, tmp_path): + p = tmp_path / "bad.py" + _write_plugin(p, """\ + def broken( + """) + plugins = Plugins() + result = plugins.verify_script(str(p)) + assert result + assert "Syntax error" in result + + def test_if_name_main_allowed(self, tmp_path): + p = tmp_path / "good.py" + _write_plugin(p, """\ + class Preload: + def __init__(self, connapp): + pass + + if __name__ == "__main__": + print("standalone") + """) + plugins = Plugins() + assert plugins.verify_script(str(p)) == False + + def test_other_if_not_allowed(self, tmp_path): + p = tmp_path / "bad.py" + _write_plugin(p, """\ + import sys + + if sys.platform == "linux": + pass + + class Preload: + def __init__(self, connapp): + pass + """) + plugins = Plugins() + result = plugins.verify_script(str(p)) + assert result + assert "__name__" in result + + +# ========================================================================= +# Import and loading tests +# ========================================================================= + +class TestPluginLoading: + def test_import_from_path(self, tmp_path): + p = tmp_path / "mymod.py" + _write_plugin(p, """\ + MY_VAR = 42 + """) + plugins = Plugins() + module = plugins._import_from_path(str(p)) + assert module.MY_VAR == 42 + + def test_import_plugins_to_argparse(self, tmp_path): + """Valid plugins get loaded into argparse.""" + import argparse + + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + _write_plugin(plugin_dir / "myplugin.py", """\ + import argparse + + class Parser: + def __init__(self): + self.parser = argparse.ArgumentParser(description="My plugin") + + class Entrypoint: + def __init__(self, args, parser, connapp): + pass + """) + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + + plugins = Plugins() + plugins._import_plugins_to_argparse(str(plugin_dir), subparsers) + + assert "myplugin" in plugins.plugins + assert "myplugin" in plugins.plugin_parsers + + def test_plugin_name_collision(self, tmp_path): + """Plugin with same name as existing subcommand is skipped.""" + import argparse + + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + _write_plugin(plugin_dir / "existcmd.py", """\ + import argparse + + class Parser: + def __init__(self): + self.parser = argparse.ArgumentParser() + + class Entrypoint: + def __init__(self, args, parser, connapp): + pass + """) + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + subparsers.add_parser("existcmd") # Already taken + + plugins = Plugins() + plugins._import_plugins_to_argparse(str(plugin_dir), subparsers) + + assert "existcmd" not in plugins.plugins + + def test_preload_registration(self, tmp_path): + """Preload class gets registered in preloads dict.""" + import argparse + + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + _write_plugin(plugin_dir / "preloader.py", """\ + class Preload: + def __init__(self, connapp): + pass + """) + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + + plugins = Plugins() + plugins._import_plugins_to_argparse(str(plugin_dir), subparsers) + + assert "preloader" in plugins.preloads + + def test_invalid_plugin_skipped(self, tmp_path, capsys): + """Invalid plugin is skipped with error message.""" + import argparse + + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + _write_plugin(plugin_dir / "badplugin.py", """\ + MY_GLOBAL = "bad" + """) + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + + plugins = Plugins() + plugins._import_plugins_to_argparse(str(plugin_dir), subparsers) + + assert "badplugin" not in plugins.plugins + captured = capsys.readouterr() + assert "Failed to load plugin" in captured.err or "Failed to load plugin" in captured.out + + def test_empty_directory(self, tmp_path): + """Empty directory doesn't cause errors.""" + import argparse + + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + + plugins = Plugins() + plugins._import_plugins_to_argparse(str(plugin_dir), subparsers) + + assert len(plugins.plugins) == 0 diff --git a/connpy/tests/test_printer.py b/connpy/tests/test_printer.py new file mode 100644 index 0000000..26c06a9 --- /dev/null +++ b/connpy/tests/test_printer.py @@ -0,0 +1,50 @@ +"""Tests for connpy.printer module.""" +import sys +from io import StringIO +from connpy import printer + + +class TestPrinter: + def test_info_output(self, capsys): + printer.info("hello world") + captured = capsys.readouterr() + assert "[i] hello world" in captured.out + + def test_success_output(self, capsys): + printer.success("done") + captured = capsys.readouterr() + assert "[✓] done" in captured.out + + def test_warning_output(self, capsys): + printer.warning("careful") + captured = capsys.readouterr() + assert "[!] careful" in captured.out + + def test_error_output(self, capsys): + printer.error("failed") + captured = capsys.readouterr() + assert "[✗] failed" in captured.err + + def test_debug_output(self, capsys): + printer.debug("debug info") + captured = capsys.readouterr() + assert "[d] debug info" in captured.out + + def test_start_output(self, capsys): + printer.start("starting") + captured = capsys.readouterr() + assert "[+] starting" in captured.out + + def test_custom_output(self, capsys): + printer.custom("TAG", "custom message") + captured = capsys.readouterr() + assert "[TAG] custom message" in captured.out + + def test_multiline_indentation(self, capsys): + printer.info("line1\nline2\nline3") + captured = capsys.readouterr() + lines = captured.out.strip().split("\n") + assert lines[0] == "[i] line1" + # Second line should be indented by len("[i] ") = 4 chars + assert lines[1].startswith(" line2") + assert lines[2].startswith(" line3") diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..f683744 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +testpaths = connpy/tests +python_files = test_*.py +python_classes = Test* +python_functions = test_*