diff --git a/.gitignore b/.gitignore
index d4c8802..46bdc24 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,3 +161,4 @@ async_interact_plan.md
repo_consolidado_limpio.md
connpy_roadmap.md
MULTI_USER_PLAN.md
+COPILOT_PLAN.md
diff --git a/connpy/core.py b/connpy/core.py
index 037ba7f..2781e12 100755
--- a/connpy/core.py
+++ b/connpy/core.py
@@ -785,7 +785,7 @@ class node:
context_mode[0] = (context_mode[0] + 1) % 3
event.app.invalidate()
- @bindings.add('escape')
+ @bindings.add('escape', eager=True)
def _(event):
cancelled[0] = True
event.app.exit(result='')
@@ -898,10 +898,11 @@ class node:
while not ai_task.done():
try:
key = os.read(sys.stdin.fileno(), 1024)
- if b'\x03' in key:
+ if b'\x03' in key or b'\x1b' in key:
cancelled = True
ai_task.cancel()
- console.print("\n[dim]Copilot cancelled via Ctrl+C.[/dim]")
+ msg = "Ctrl+C" if b'\x03' in key else "Esc"
+ console.print(f"\n[dim]Copilot cancelled via {msg}.[/dim]")
break
except OSError:
pass
@@ -953,7 +954,7 @@ class node:
confirm_session = PromptSession()
confirm_bindings = KeyBindings()
- @confirm_bindings.add('escape')
+ @confirm_bindings.add('escape', eager=True)
def _(event):
event.app.exit(result='n')
@@ -994,11 +995,21 @@ class node:
if cmds_to_edit:
target_cmd = "\n".join(cmds_to_edit)
+ edit_bindings = KeyBindings()
+ @edit_bindings.add('c-j')
+ def _(event):
+ event.app.exit(result=event.app.current_buffer.text)
+
+ @edit_bindings.add('escape', eager=True)
+ def _(event):
+ event.app.exit(result='')
+
try:
edited_cmd = await edit_session.prompt_async(
- HTML("Edit commands (Alt+Enter or Esc,Enter to submit):\n"),
+ HTML("Edit commands (Ctrl+Enter to submit, Esc to cancel):\n"),
default=target_cmd,
- multiline=True
+ multiline=True,
+ key_bindings=edit_bindings
)
if edited_cmd.strip():
os.write(child_fd, b'\x15')
diff --git a/connpy/grpc_layer/stubs.py b/connpy/grpc_layer/stubs.py
index 05f04fb..ebf90db 100644
--- a/connpy/grpc_layer/stubs.py
+++ b/connpy/grpc_layer/stubs.py
@@ -292,7 +292,7 @@ class NodeStub:
context_mode[0] = (context_mode[0] + 1) % 3
event.app.invalidate()
- @bindings.add('escape')
+ @bindings.add('escape', eager=True)
def _(event):
event.app.exit(result='')
@@ -363,10 +363,11 @@ class NodeStub:
# 1. Read input for Ctrl+C
try:
key = os.read(sys.stdin.fileno(), 1024)
- if b'\x03' in key:
+ if b'\x03' in key or b'\x1b' in key:
cancelled = True
request_queue.put(connpy_pb2.InteractRequest(copilot_question="CANCEL"))
- console.print("\n[dim]Copilot cancelled via Ctrl+C. Disconnecting...[/dim]")
+ msg = "Ctrl+C" if b'\x03' in key else "Esc"
+ console.print(f"\n[dim]Copilot cancelled via {msg}.[/dim]")
break
except OSError:
pass
@@ -421,7 +422,7 @@ class NodeStub:
try:
confirm_session = PromptSession()
confirm_bindings = KeyBindings()
- @confirm_bindings.add('escape')
+ @confirm_bindings.add('escape', eager=True)
def _(event):
event.app.exit(result='n')
@@ -454,10 +455,19 @@ class NodeStub:
if cmds_to_edit:
target_cmd = "\n".join(cmds_to_edit)
try:
+ edit_bindings = KeyBindings()
+ @edit_bindings.add('c-j')
+ def _(event):
+ event.app.exit(result=event.app.current_buffer.text)
+ @edit_bindings.add('escape', eager=True)
+ def _(event):
+ event.app.exit(result='')
+
edited_cmd = edit_session.prompt(
- HTML("Edit commands (Alt+Enter or Esc,Enter to submit):\n"),
+ HTML("Edit commands (Ctrl+Enter to submit, Esc to cancel):\n"),
default=target_cmd,
- multiline=True
+ multiline=True,
+ key_bindings=edit_bindings
)
if edited_cmd.strip():
action_sent = "custom:" + edited_cmd.strip()
@@ -737,7 +747,7 @@ class NodeStub:
context_mode[0] = (context_mode[0] + 1) % 3
event.app.invalidate()
- @bindings.add('escape')
+ @bindings.add('escape', eager=True)
def _(event):
event.app.exit(result='')
@@ -806,10 +816,11 @@ class NodeStub:
while True:
try:
key = os.read(sys.stdin.fileno(), 1024)
- if b'\x03' in key:
+ if b'\x03' in key or b'\x1b' in key:
cancelled = True
request_queue.put(connpy_pb2.InteractRequest(copilot_question="CANCEL"))
- console.print("\n[dim]Copilot cancelled via Ctrl+C. Disconnecting...[/dim]")
+ msg = "Ctrl+C" if b'\x03' in key else "Esc"
+ console.print(f"\n[dim]Copilot cancelled via {msg}.[/dim]")
break
except OSError:
pass
@@ -862,7 +873,7 @@ class NodeStub:
try:
confirm_session = PromptSession()
confirm_bindings = KeyBindings()
- @confirm_bindings.add('escape')
+ @confirm_bindings.add('escape', eager=True)
def _(event):
event.app.exit(result='n')
@@ -895,10 +906,19 @@ class NodeStub:
if cmds_to_edit:
target_cmd = "\n".join(cmds_to_edit)
try:
+ edit_bindings = KeyBindings()
+ @edit_bindings.add('c-j')
+ def _(event):
+ event.app.exit(result=event.app.current_buffer.text)
+ @edit_bindings.add('escape', eager=True)
+ def _(event):
+ event.app.exit(result='')
+
edited_cmd = edit_session.prompt(
- HTML("Edit commands (Alt+Enter or Esc,Enter to submit):\n"),
+ HTML("Edit commands (Ctrl+Enter to submit, Esc to cancel):\n"),
default=target_cmd,
- multiline=True
+ multiline=True,
+ key_bindings=edit_bindings
)
if edited_cmd.strip():
action_sent = "custom:" + edited_cmd.strip()
diff --git a/connpy/tests/test_ai_copilot.py b/connpy/tests/test_ai_copilot.py
new file mode 100644
index 0000000..524bcbe
--- /dev/null
+++ b/connpy/tests/test_ai_copilot.py
@@ -0,0 +1,136 @@
+import pytest
+from unittest.mock import MagicMock, patch
+import json
+import asyncio
+
+from connpy.ai import ai
+from connpy.core import node
+
+class DummyConfig:
+ def __init__(self):
+ self.config = {"ai": {"engineer_api_key": "test_key", "engineer_model": "test_model"}}
+ self.defaultdir = "/tmp"
+
+@pytest.fixture
+def mock_completion():
+ with patch('connpy.ai.completion') as mock:
+ yield mock
+
+def test_ask_copilot_tool_call(mock_completion):
+ agent = ai(DummyConfig())
+
+ # Setup mock response for streaming
+ class MockDelta:
+ def __init__(self, content):
+ self.content = content
+
+ class MockChoice:
+ def __init__(self, content):
+ self.delta = MockDelta(content)
+
+ class MockChunk:
+ def __init__(self, content):
+ self.choices = [MockChoice(content)]
+
+ mock_completion.return_value = [
+ MockChunk("Check the interfaces and running config."),
+ MockChunk("\nshow ip int br\nshow run\n"),
+ MockChunk("low")
+ ]
+
+ result = agent.ask_copilot("Router#", "What do I do?")
+
+ if result["error"]:
+ print(f"ERROR OCCURRED: {result['error']}")
+
+ assert result["error"] is None
+ assert result["guide"] == "Check the interfaces and running config."
+ assert result["risk_level"] == "low"
+ assert result["commands"] == ["show ip int br", "show run"]
+
+def test_ask_copilot_fallback(mock_completion):
+ agent = ai(DummyConfig())
+
+ # Setup mock response for streaming
+ class MockDelta:
+ def __init__(self, content):
+ self.content = content
+
+ class MockChoice:
+ def __init__(self, content):
+ self.delta = MockDelta(content)
+
+ class MockChunk:
+ def __init__(self, content):
+ self.choices = [MockChoice(content)]
+
+ mock_completion.return_value = [
+ MockChunk("Here is some text response instead of tool call.")
+ ]
+
+ result = agent.ask_copilot("Router#", "What do I do?")
+
+ if result["error"]:
+ print(f"ERROR OCCURRED: {result['error']}")
+
+ assert result["error"] is None
+ assert result["guide"] == "Here is some text response instead of tool call."
+ assert result["risk_level"] == "low"
+
+def test_logclean_ansi():
+ c = node("test_node", "1.2.3.4")
+ raw = "Router#\x1b[K\x1b[m show ip"
+ clean = c._logclean(raw, var=True)
+ assert "\x1b" not in clean
+
+def test_ingress_task_interception():
+ async def run_test():
+ c = node("test_node", "1.2.3.4")
+ c.mylog = MagicMock()
+ c.mylog.getvalue.return_value = b"Some session log"
+ c.unique = "test_node"
+ c.host = "1.2.3.4"
+ c.tags = {"os": "cisco_ios"}
+
+ class MockStream:
+ def __init__(self):
+ self.data = [b"a", b"b", b"\x00", b"c", b""]
+ async def read(self):
+ if self.data:
+ return self.data.pop(0)
+ return b""
+ def setup(self, resize_callback):
+ pass
+
+ stream = MockStream()
+
+ called_copilot = False
+ async def mock_handler(buffer, node_info, s, child_fd):
+ nonlocal called_copilot
+ called_copilot = True
+ assert buffer == "Some session log"
+ assert node_info["os"] == "cisco_ios"
+
+ c.child = MagicMock()
+ c.child.child_fd = 123
+ c.child.after = b""
+ c.child.buffer = b""
+
+ async def mock_ingress():
+ while True:
+ data = await stream.read()
+ if not data:
+ break
+
+ if mock_handler and b'\x00' in data:
+ buffer = c.mylog.getvalue().decode()
+ node_info = {"name": getattr(c, 'unique', 'unknown'), "host": getattr(c, 'host', 'unknown')}
+ if isinstance(getattr(c, 'tags', None), dict):
+ node_info["os"] = c.tags.get("os", "unknown")
+ await mock_handler(buffer, node_info, stream, c.child.child_fd)
+ continue
+
+ await mock_ingress()
+ assert called_copilot
+
+ asyncio.run(run_test())