refactor(core): stabilize gRPC streaming, plugin invocation, and CLI UX
- Implement threaded plugin execution with Queue-based streaming in PluginService - Refactor remote logger to preserve ANSI colors and fix TTY line endings (\r\n) - Intelligent terminal filtering: disable SSM screen-clearing filter after success - Sanitize SSH-only flags in core.py when using SFTP protocol - Rewrite completion tree with pre/post-node states and flag deduplication - Update gRPC unit tests to match new streaming response structure
This commit is contained in:
@@ -184,9 +184,37 @@ def _build_tree(nodes, folders, profiles, plugins, configdir):
|
||||
"folders": None,
|
||||
}
|
||||
|
||||
# --- Connect (default command) ---
|
||||
# Long flags are offered; short forms (-d/-t) only used for navigation.
|
||||
# Two states: before node (offer nodes + remaining long flags)
|
||||
# after node (offer only remaining long flags, no more nodes)
|
||||
connect_flags_long = ["--debug", "--sftp"]
|
||||
connect_flags_all = ["--debug", "-d", "--sftp", "-t"]
|
||||
|
||||
# Post-node: only offer remaining long flags
|
||||
connect_after_node = {"__exclude_used__": True}
|
||||
for f in connect_flags_all:
|
||||
connect_after_node[f] = connect_after_node
|
||||
|
||||
# Pre-node: offer nodes + remaining long flags, consume node → post-node state
|
||||
connect_dict = {"__exclude_used__": True}
|
||||
connect_dict["__extra__"] = lambda w: (
|
||||
list(nodes) + list(folders) + (list(plugins.keys()) if plugins else [])
|
||||
)
|
||||
connect_dict["*"] = connect_after_node
|
||||
for f in connect_flags_all:
|
||||
connect_dict[f] = connect_dict
|
||||
|
||||
# --- Main Tree ---
|
||||
return {
|
||||
# Root: offer nodes + long flags; after a node go to post-node state
|
||||
"__extra__": lambda w: list(nodes) + list(folders) + (list(plugins.keys()) if plugins else []),
|
||||
"*": connect_after_node,
|
||||
|
||||
"--debug": connect_dict,
|
||||
"-d": connect_dict,
|
||||
"--sftp": connect_dict,
|
||||
"-t": connect_dict,
|
||||
|
||||
"--add": {"profile": _profile_values},
|
||||
"--del": {"profile": _profile_values, "__extra__": _nodes_folders},
|
||||
|
||||
+9
-2
@@ -348,7 +348,8 @@ class node:
|
||||
x.start()
|
||||
if debug:
|
||||
if 'mylog' in dir(self):
|
||||
print(self.mylog.getvalue().decode())
|
||||
if not async_mode:
|
||||
print(self.mylog.getvalue().decode())
|
||||
|
||||
def _teardown_interact_environment(self):
|
||||
if 'logfile' in dir(self) and hasattr(self, 'mylog'):
|
||||
@@ -760,7 +761,12 @@ class node:
|
||||
elif self.protocol == "sftp":
|
||||
cmd += " -P " + self.port
|
||||
if self.options:
|
||||
cmd += " " + self.options
|
||||
opts = self.options
|
||||
if self.protocol == "sftp":
|
||||
# Strip SSH-only flags that sftp doesn't support
|
||||
opts = re.sub(r'(?<!\S)-[XxtTAaNf]\b', '', opts).strip()
|
||||
if opts:
|
||||
cmd += " " + opts
|
||||
if self.jumphost:
|
||||
cmd += " " + self.jumphost
|
||||
user_host = f"{self.user}@{self.host}" if self.user else self.host
|
||||
@@ -875,6 +881,7 @@ class node:
|
||||
if logger:
|
||||
logger("debug", f"Command:\n{cmd}")
|
||||
self.mylog = io.BytesIO()
|
||||
self.mylog.write(f"[i] [DEBUG] Command:\r\n {cmd}\r\n".encode())
|
||||
child.logfile_read = self.mylog
|
||||
|
||||
|
||||
|
||||
@@ -139,7 +139,39 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
if sftp:
|
||||
n.protocol = "sftp"
|
||||
|
||||
connect = n._connect(debug=debug)
|
||||
# Build a logger that captures debug messages as ANSI-colored bytes for the client
|
||||
debug_chunks = []
|
||||
if debug:
|
||||
from io import StringIO
|
||||
from rich.console import Console as RichConsole
|
||||
from ..printer import connpy_theme
|
||||
from .. import printer as _printer
|
||||
|
||||
def remote_logger(msg_type, message):
|
||||
buf = StringIO()
|
||||
c = RichConsole(file=buf, force_terminal=True, width=120, theme=connpy_theme)
|
||||
if msg_type == "debug":
|
||||
c.print(_printer._format_multiline("i", f"[DEBUG] {message}", style="info"))
|
||||
elif msg_type == "success":
|
||||
c.print(_printer._format_multiline("✓", message, style="success"))
|
||||
elif msg_type == "error":
|
||||
c.print(_printer._format_multiline("✗", message, style="error"))
|
||||
else:
|
||||
c.print(str(message))
|
||||
rendered = buf.getvalue()
|
||||
if rendered:
|
||||
# Raw TTY needs \r\n instead of \n
|
||||
rendered = rendered.replace('\n', '\r\n')
|
||||
debug_chunks.append(rendered.encode())
|
||||
else:
|
||||
remote_logger = None
|
||||
|
||||
connect = n._connect(debug=debug, logger=remote_logger)
|
||||
|
||||
# Send debug output to client before checking result (always show the command)
|
||||
for chunk in debug_chunks:
|
||||
yield connpy_pb2.InteractResponse(stdout_data=chunk)
|
||||
|
||||
if connect != True:
|
||||
yield connpy_pb2.InteractResponse(success=False, error_message=str(connect))
|
||||
return
|
||||
|
||||
+48
-24
@@ -86,25 +86,37 @@ class NodeStub:
|
||||
|
||||
old_tty = termios.tcgetattr(sys.stdin)
|
||||
try:
|
||||
import time
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
response_iterator = self.stub.interact_node(request_generator())
|
||||
|
||||
# First response is connection status
|
||||
# First phase: Wait for connection status, print early data
|
||||
try:
|
||||
first_res = next(response_iterator)
|
||||
if first_res.success:
|
||||
# Connection established on server, show success message
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.success(conn_msg)
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
else:
|
||||
# Connection failed on server
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.error(f"Connection failed: {first_res.error_message}")
|
||||
return
|
||||
for res in response_iterator:
|
||||
if res.stdout_data:
|
||||
data = res.stdout_data
|
||||
if debug:
|
||||
data = data.replace(b'\x1b[H\x1b[2J', b'').replace(b'\x1bc', b'').replace(b'\x1b[3J', b'')
|
||||
os.write(sys.stdout.fileno(), data)
|
||||
|
||||
if res.success:
|
||||
# Connection established on server, show success message
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.success(conn_msg)
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
break
|
||||
|
||||
if res.error_message:
|
||||
# Connection failed on server
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.error(f"Connection failed: {res.error_message}")
|
||||
return
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
# Second phase: Stream active session
|
||||
# Clear screen filter is only applied before success (Phase 1).
|
||||
# Once the user has a prompt, Ctrl+L must work normally.
|
||||
for res in response_iterator:
|
||||
if res.stdout_data:
|
||||
os.write(sys.stdout.fileno(), res.stdout_data)
|
||||
@@ -160,25 +172,37 @@ class NodeStub:
|
||||
|
||||
old_tty = termios.tcgetattr(sys.stdin)
|
||||
try:
|
||||
import time
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
response_iterator = self.stub.interact_node(request_generator())
|
||||
|
||||
# First response is connection status
|
||||
# First phase: Wait for connection status, print early data
|
||||
try:
|
||||
first_res = next(response_iterator)
|
||||
if first_res.success:
|
||||
# Connection established on server, show success message
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.success(conn_msg)
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
else:
|
||||
# Connection failed on server
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.error(f"Connection failed: {first_res.error_message}")
|
||||
return
|
||||
for res in response_iterator:
|
||||
if res.stdout_data:
|
||||
data = res.stdout_data
|
||||
if debug:
|
||||
data = data.replace(b'\x1b[H\x1b[2J', b'').replace(b'\x1bc', b'').replace(b'\x1b[3J', b'')
|
||||
os.write(sys.stdout.fileno(), data)
|
||||
|
||||
if res.success:
|
||||
# Connection established on server, show success message
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.success(conn_msg)
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
break
|
||||
|
||||
if res.error_message:
|
||||
# Connection failed on server
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.error(f"Connection failed: {res.error_message}")
|
||||
return
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
# Second phase: Stream active session
|
||||
# Clear screen filter is only applied before success (Phase 1).
|
||||
# Once the user has a prompt, Ctrl+L must work normally.
|
||||
for res in response_iterator:
|
||||
if res.stdout_data:
|
||||
os.write(sys.stdout.fileno(), res.stdout_data)
|
||||
|
||||
@@ -233,25 +233,44 @@ class PluginService(BaseService):
|
||||
from rich.console import Console
|
||||
|
||||
from rich.console import Console
|
||||
buf = io.StringIO()
|
||||
import queue
|
||||
import threading
|
||||
|
||||
q = queue.Queue()
|
||||
|
||||
class QueueIO(io.StringIO):
|
||||
def write(self, s):
|
||||
q.put(s)
|
||||
return len(s)
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
buf = QueueIO()
|
||||
old_console = printer._get_console()
|
||||
old_err_console = printer._get_err_console()
|
||||
|
||||
printer.set_thread_console(Console(file=buf, theme=printer.connpy_theme, force_terminal=True))
|
||||
printer.set_thread_err_console(Console(file=buf, theme=printer.connpy_theme, force_terminal=True))
|
||||
printer.set_thread_stream(buf)
|
||||
def run_plugin():
|
||||
printer.set_thread_console(Console(file=buf, theme=printer.connpy_theme, force_terminal=True))
|
||||
printer.set_thread_err_console(Console(file=buf, theme=printer.connpy_theme, force_terminal=True))
|
||||
printer.set_thread_stream(buf)
|
||||
try:
|
||||
if hasattr(module, "Entrypoint"):
|
||||
module.Entrypoint(args, parser, app)
|
||||
except BaseException as e:
|
||||
if not isinstance(e, SystemExit):
|
||||
import traceback
|
||||
printer.err_console.print(traceback.format_exc())
|
||||
finally:
|
||||
printer.set_thread_console(old_console)
|
||||
printer.set_thread_err_console(old_err_console)
|
||||
printer.set_thread_stream(None)
|
||||
q.put(None)
|
||||
|
||||
try:
|
||||
if hasattr(module, "Entrypoint"):
|
||||
module.Entrypoint(args, parser, app)
|
||||
except BaseException as e:
|
||||
if not isinstance(e, SystemExit):
|
||||
import traceback
|
||||
printer.err_console.print(traceback.format_exc())
|
||||
finally:
|
||||
printer.set_thread_console(old_console)
|
||||
printer.set_thread_err_console(old_err_console)
|
||||
printer.set_thread_stream(None)
|
||||
t = threading.Thread(target=run_plugin, daemon=True)
|
||||
t.start()
|
||||
|
||||
for line in buf.getvalue().splitlines(keepends=True):
|
||||
yield line
|
||||
while True:
|
||||
item = q.get()
|
||||
if item is None:
|
||||
break
|
||||
yield item
|
||||
|
||||
@@ -85,8 +85,8 @@ class TestStubsMessageFormatting:
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.success = True
|
||||
mock_resp.stdout_data = b''
|
||||
stub.stub.interact_node.return_value = iter([mock_resp])
|
||||
|
||||
with patch("connpy.printer.success") as mock_success:
|
||||
with patch("sys.stdin.fileno", return_value=0):
|
||||
mock_select.return_value = ([], [], [])
|
||||
|
||||
Reference in New Issue
Block a user