diff --git a/.gitignore b/.gitignore index d4c8802..46bdc24 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,4 @@ async_interact_plan.md repo_consolidado_limpio.md connpy_roadmap.md MULTI_USER_PLAN.md +COPILOT_PLAN.md diff --git a/connpy/core.py b/connpy/core.py index 037ba7f..2781e12 100755 --- a/connpy/core.py +++ b/connpy/core.py @@ -785,7 +785,7 @@ class node: context_mode[0] = (context_mode[0] + 1) % 3 event.app.invalidate() - @bindings.add('escape') + @bindings.add('escape', eager=True) def _(event): cancelled[0] = True event.app.exit(result='') @@ -898,10 +898,11 @@ class node: while not ai_task.done(): try: key = os.read(sys.stdin.fileno(), 1024) - if b'\x03' in key: + if b'\x03' in key or b'\x1b' in key: cancelled = True ai_task.cancel() - console.print("\n[dim]Copilot cancelled via Ctrl+C.[/dim]") + msg = "Ctrl+C" if b'\x03' in key else "Esc" + console.print(f"\n[dim]Copilot cancelled via {msg}.[/dim]") break except OSError: pass @@ -953,7 +954,7 @@ class node: confirm_session = PromptSession() confirm_bindings = KeyBindings() - @confirm_bindings.add('escape') + @confirm_bindings.add('escape', eager=True) def _(event): event.app.exit(result='n') @@ -994,11 +995,21 @@ class node: if cmds_to_edit: target_cmd = "\n".join(cmds_to_edit) + edit_bindings = KeyBindings() + @edit_bindings.add('c-j') + def _(event): + event.app.exit(result=event.app.current_buffer.text) + + @edit_bindings.add('escape', eager=True) + def _(event): + event.app.exit(result='') + try: edited_cmd = await edit_session.prompt_async( - HTML("Edit commands (Alt+Enter or Esc,Enter to submit):\n"), + HTML("Edit commands (Ctrl+Enter to submit, Esc to cancel):\n"), default=target_cmd, - multiline=True + multiline=True, + key_bindings=edit_bindings ) if edited_cmd.strip(): os.write(child_fd, b'\x15') diff --git a/connpy/grpc_layer/stubs.py b/connpy/grpc_layer/stubs.py index 05f04fb..ebf90db 100644 --- a/connpy/grpc_layer/stubs.py +++ b/connpy/grpc_layer/stubs.py @@ -292,7 +292,7 @@ class NodeStub: context_mode[0] = (context_mode[0] + 1) % 3 event.app.invalidate() - @bindings.add('escape') + @bindings.add('escape', eager=True) def _(event): event.app.exit(result='') @@ -363,10 +363,11 @@ class NodeStub: # 1. Read input for Ctrl+C try: key = os.read(sys.stdin.fileno(), 1024) - if b'\x03' in key: + if b'\x03' in key or b'\x1b' in key: cancelled = True request_queue.put(connpy_pb2.InteractRequest(copilot_question="CANCEL")) - console.print("\n[dim]Copilot cancelled via Ctrl+C. Disconnecting...[/dim]") + msg = "Ctrl+C" if b'\x03' in key else "Esc" + console.print(f"\n[dim]Copilot cancelled via {msg}.[/dim]") break except OSError: pass @@ -421,7 +422,7 @@ class NodeStub: try: confirm_session = PromptSession() confirm_bindings = KeyBindings() - @confirm_bindings.add('escape') + @confirm_bindings.add('escape', eager=True) def _(event): event.app.exit(result='n') @@ -454,10 +455,19 @@ class NodeStub: if cmds_to_edit: target_cmd = "\n".join(cmds_to_edit) try: + edit_bindings = KeyBindings() + @edit_bindings.add('c-j') + def _(event): + event.app.exit(result=event.app.current_buffer.text) + @edit_bindings.add('escape', eager=True) + def _(event): + event.app.exit(result='') + edited_cmd = edit_session.prompt( - HTML("Edit commands (Alt+Enter or Esc,Enter to submit):\n"), + HTML("Edit commands (Ctrl+Enter to submit, Esc to cancel):\n"), default=target_cmd, - multiline=True + multiline=True, + key_bindings=edit_bindings ) if edited_cmd.strip(): action_sent = "custom:" + edited_cmd.strip() @@ -737,7 +747,7 @@ class NodeStub: context_mode[0] = (context_mode[0] + 1) % 3 event.app.invalidate() - @bindings.add('escape') + @bindings.add('escape', eager=True) def _(event): event.app.exit(result='') @@ -806,10 +816,11 @@ class NodeStub: while True: try: key = os.read(sys.stdin.fileno(), 1024) - if b'\x03' in key: + if b'\x03' in key or b'\x1b' in key: cancelled = True request_queue.put(connpy_pb2.InteractRequest(copilot_question="CANCEL")) - console.print("\n[dim]Copilot cancelled via Ctrl+C. Disconnecting...[/dim]") + msg = "Ctrl+C" if b'\x03' in key else "Esc" + console.print(f"\n[dim]Copilot cancelled via {msg}.[/dim]") break except OSError: pass @@ -862,7 +873,7 @@ class NodeStub: try: confirm_session = PromptSession() confirm_bindings = KeyBindings() - @confirm_bindings.add('escape') + @confirm_bindings.add('escape', eager=True) def _(event): event.app.exit(result='n') @@ -895,10 +906,19 @@ class NodeStub: if cmds_to_edit: target_cmd = "\n".join(cmds_to_edit) try: + edit_bindings = KeyBindings() + @edit_bindings.add('c-j') + def _(event): + event.app.exit(result=event.app.current_buffer.text) + @edit_bindings.add('escape', eager=True) + def _(event): + event.app.exit(result='') + edited_cmd = edit_session.prompt( - HTML("Edit commands (Alt+Enter or Esc,Enter to submit):\n"), + HTML("Edit commands (Ctrl+Enter to submit, Esc to cancel):\n"), default=target_cmd, - multiline=True + multiline=True, + key_bindings=edit_bindings ) if edited_cmd.strip(): action_sent = "custom:" + edited_cmd.strip() diff --git a/connpy/tests/test_ai_copilot.py b/connpy/tests/test_ai_copilot.py new file mode 100644 index 0000000..524bcbe --- /dev/null +++ b/connpy/tests/test_ai_copilot.py @@ -0,0 +1,136 @@ +import pytest +from unittest.mock import MagicMock, patch +import json +import asyncio + +from connpy.ai import ai +from connpy.core import node + +class DummyConfig: + def __init__(self): + self.config = {"ai": {"engineer_api_key": "test_key", "engineer_model": "test_model"}} + self.defaultdir = "/tmp" + +@pytest.fixture +def mock_completion(): + with patch('connpy.ai.completion') as mock: + yield mock + +def test_ask_copilot_tool_call(mock_completion): + agent = ai(DummyConfig()) + + # Setup mock response for streaming + class MockDelta: + def __init__(self, content): + self.content = content + + class MockChoice: + def __init__(self, content): + self.delta = MockDelta(content) + + class MockChunk: + def __init__(self, content): + self.choices = [MockChoice(content)] + + mock_completion.return_value = [ + MockChunk("Check the interfaces and running config."), + MockChunk("\nshow ip int br\nshow run\n"), + MockChunk("low") + ] + + result = agent.ask_copilot("Router#", "What do I do?") + + if result["error"]: + print(f"ERROR OCCURRED: {result['error']}") + + assert result["error"] is None + assert result["guide"] == "Check the interfaces and running config." + assert result["risk_level"] == "low" + assert result["commands"] == ["show ip int br", "show run"] + +def test_ask_copilot_fallback(mock_completion): + agent = ai(DummyConfig()) + + # Setup mock response for streaming + class MockDelta: + def __init__(self, content): + self.content = content + + class MockChoice: + def __init__(self, content): + self.delta = MockDelta(content) + + class MockChunk: + def __init__(self, content): + self.choices = [MockChoice(content)] + + mock_completion.return_value = [ + MockChunk("Here is some text response instead of tool call.") + ] + + result = agent.ask_copilot("Router#", "What do I do?") + + if result["error"]: + print(f"ERROR OCCURRED: {result['error']}") + + assert result["error"] is None + assert result["guide"] == "Here is some text response instead of tool call." + assert result["risk_level"] == "low" + +def test_logclean_ansi(): + c = node("test_node", "1.2.3.4") + raw = "Router#\x1b[K\x1b[m show ip" + clean = c._logclean(raw, var=True) + assert "\x1b" not in clean + +def test_ingress_task_interception(): + async def run_test(): + c = node("test_node", "1.2.3.4") + c.mylog = MagicMock() + c.mylog.getvalue.return_value = b"Some session log" + c.unique = "test_node" + c.host = "1.2.3.4" + c.tags = {"os": "cisco_ios"} + + class MockStream: + def __init__(self): + self.data = [b"a", b"b", b"\x00", b"c", b""] + async def read(self): + if self.data: + return self.data.pop(0) + return b"" + def setup(self, resize_callback): + pass + + stream = MockStream() + + called_copilot = False + async def mock_handler(buffer, node_info, s, child_fd): + nonlocal called_copilot + called_copilot = True + assert buffer == "Some session log" + assert node_info["os"] == "cisco_ios" + + c.child = MagicMock() + c.child.child_fd = 123 + c.child.after = b"" + c.child.buffer = b"" + + async def mock_ingress(): + while True: + data = await stream.read() + if not data: + break + + if mock_handler and b'\x00' in data: + buffer = c.mylog.getvalue().decode() + node_info = {"name": getattr(c, 'unique', 'unknown'), "host": getattr(c, 'host', 'unknown')} + if isinstance(getattr(c, 'tags', None), dict): + node_info["os"] = c.tags.get("os", "unknown") + await mock_handler(buffer, node_info, stream, c.child.child_fd) + continue + + await mock_ingress() + assert called_copilot + + asyncio.run(run_test())