diff --git a/connpy/_version.py b/connpy/_version.py index b1a5057..610335e 100644 --- a/connpy/_version.py +++ b/connpy/_version.py @@ -1 +1 @@ -__version__ = "5.1b6" +__version__ = "6.0.0b1" diff --git a/connpy/core.py b/connpy/core.py index 96de1ab..5efaebb 100755 --- a/connpy/core.py +++ b/connpy/core.py @@ -14,7 +14,10 @@ from pathlib import Path from copy import deepcopy from .hooks import ClassHook, MethodHook import io +import asyncio +import fcntl from . import printer +from .tunnels import LocalStream #functions and classes @@ -189,23 +192,54 @@ class node: @MethodHook def _logclean(self, logfile, var = False): - #Remove special ascii characters and other stuff from logfile. + # Remove special ascii characters and process terminal cursor movements to clean logs. if var == False: t = open(logfile, "r").read() else: t = logfile - while t.find("\b") != -1: - t = re.sub('[^\b]\b', '', t) - t = t.replace("\n","",1) - t = t.replace("\a","") - t = t.replace('\n\n', '\n') - t = re.sub(r'.\[K', '', t) - ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/ ]*[@-~])') - t = ansi_escape.sub('', t) - t = t.lstrip(" \n\r") - t = t.replace("\r","") - t = t.replace("\x0E","") - t = t.replace("\x0F","") + + lines = t.split('\n') + cleaned_lines = [] + + # Regex to capture: ANSI sequences, control characters (\r, \b, etc), and plain text chunks + token_re = re.compile(r'(\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/ ]*[@-~])|\r|\b|\x7f|[\x00-\x1F]|[^\x1B\r\b\x7f\x00-\x1F]+)') + + for line in lines: + buffer = [] + cursor = 0 + + for token in token_re.findall(line): + if token == '\r': + cursor = 0 + elif token in ('\b', '\x7f'): + if cursor > 0: + cursor -= 1 + elif token == '\x1B[D': # Left Arrow + if cursor > 0: + cursor -= 1 + elif token == '\x1B[C': # Right Arrow + if cursor < len(buffer): + cursor += 1 + elif token == '\x1B[K': # Clear to end of line + buffer = buffer[:cursor] + elif token.startswith('\x1B'): + # Ignore other ANSI sequences (colors, etc) + continue + elif len(token) == 1 and ord(token) < 32: + # Ignore other non-printable control chars + continue + else: + # Regular printable text + for char in token: + if cursor == len(buffer): + buffer.append(char) + else: + buffer[cursor] = char + cursor += 1 + cleaned_lines.append("".join(buffer)) + + t = "\n".join(cleaned_lines).replace('\n\n', '\n').strip() + if var == False: d = open(logfile, "w") d.write(t) @@ -248,48 +282,193 @@ class node: sleep(1) - @MethodHook - def interact(self, debug = False, logger = None): - ''' - Allow user to interact with the node directly, mostly used by connection manager. + def _setup_interact_environment(self, debug=False, logger=None, async_mode=False): + size = re.search('columns=([0-9]+).*lines=([0-9]+)',str(os.get_terminal_size())) + self.child.setwinsize(int(size.group(2)),int(size.group(1))) + if logger: + port_str = f":{self.port}" if self.port and self.protocol not in ["ssm", "kubectl", "docker"] else "" + logger("success", f"Connected to {self.unique} at {self.host}{port_str} via: {self.protocol}") - ### Optional Parameters: - - - debug (bool): If True, display all the connecting information - before interact. Default False. - - logger (callable): Optional callback for status reporting. - ''' - connect = self._connect(debug = debug, logger = logger) - if connect == True: - size = re.search('columns=([0-9]+).*lines=([0-9]+)',str(os.get_terminal_size())) - self.child.setwinsize(int(size.group(2)),int(size.group(1))) - if logger: - port_str = f":{self.port}" if self.port and self.protocol not in ["ssm", "kubectl", "docker"] else "" - logger("success", f"Connected to {self.unique} at {self.host}{port_str} via: {self.protocol}") - - if 'logfile' in dir(self): - # Initialize self.mylog - if not 'mylog' in dir(self): - self.mylog = io.BytesIO() + if 'logfile' in dir(self): + # Initialize self.mylog + if not 'mylog' in dir(self): + self.mylog = io.BytesIO() + if not async_mode: self.child.logfile_read = self.mylog # Start the _savelog thread log_thread = threading.Thread(target=self._savelog) log_thread.daemon = True log_thread.start() - if 'missingtext' in dir(self): - print(self.child.after.decode(), end='') - if self.idletime > 0: - x = threading.Thread(target=self._keepalive) - x.daemon = True - x.start() - if debug: + if 'missingtext' in dir(self): + print(self.child.after.decode(), end='') + if self.idletime > 0 and not async_mode: + x = threading.Thread(target=self._keepalive) + x.daemon = True + x.start() + if debug: + if 'mylog' in dir(self): print(self.mylog.getvalue().decode()) - self.child.interact(input_filter=self._filter) - if 'logfile' in dir(self): - with open(self.logfile, "w") as f: - f.write(self._logclean(self.mylog.getvalue().decode(), True)) + def _teardown_interact_environment(self): + if 'logfile' in dir(self) and hasattr(self, 'mylog'): + with open(self.logfile, "w") as f: + f.write(self._logclean(self.mylog.getvalue().decode(), True)) + + async def _async_interact_loop(self, local_stream, resize_callback): + local_stream.setup(resize_callback=resize_callback) + try: + child_fd = self.child.child_fd + + # 1. Flush ghost buffer (Clean UX) + ghost_buffer = b'' + if getattr(self, 'missingtext', False): + # If we are missing the password, we MUST show the password prompt + ghost_buffer = (self.child.after or b'') + (self.child.buffer or b'') + else: + # We auto-logged in. Hide the messy password negotiation and just keep any pending live stream. + ghost_buffer = self.child.buffer or b'' + + # Fix user's pet peeve: Strip leading newlines to avoid the empty lines + # the router echoes after receiving the password or blank line. + if not getattr(self, 'missingtext', False): + ghost_buffer = ghost_buffer.lstrip(b'\r\n ') + + if ghost_buffer: + # Add a single clean newline so it doesn't merge with the Connected message + await local_stream.write(b'\r\n' + ghost_buffer) + if hasattr(self, 'mylog'): + self.mylog.write(b'\n' + ghost_buffer) + + self.child.buffer = b'' + self.child.before = b'' + + # 2. Set child fd non-blocking + flags = fcntl.fcntl(child_fd, fcntl.F_GETFL) + fcntl.fcntl(child_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + loop = asyncio.get_running_loop() + child_reader_queue = asyncio.Queue() + + def _child_read_ready(): + try: + data = os.read(child_fd, 4096) + if data: + child_reader_queue.put_nowait(data) + else: + child_reader_queue.put_nowait(b'') + except BlockingIOError: + pass + except OSError: + child_reader_queue.put_nowait(b'') + + loop.add_reader(child_fd, _child_read_ready) + self.lastinput = time() + + async def ingress_task(): + while True: + data = await local_stream.read() + if not data: + break + try: + os.write(child_fd, data) + except OSError: + break + self.lastinput = time() + + async def egress_task(): + # Continue stripping newlines from the live stream until we hit real text + skip_newlines = not getattr(self, 'missingtext', False) and not ghost_buffer + while True: + data = await child_reader_queue.get() + if not data: + break + + if skip_newlines: + stripped = data.lstrip(b'\r\n') + if stripped: + skip_newlines = False + data = stripped + else: + continue + + await local_stream.write(data) + if hasattr(self, 'mylog'): + self.mylog.write(data) + + async def keepalive_task(): + if self.idletime <= 0: + return + while True: + await asyncio.sleep(1) + if time() - self.lastinput >= self.idletime: + try: + self.child.sendcontrol("e") + self.lastinput = time() + except Exception: + pass + + async def savelog_task(): + if not hasattr(self, 'logfile') or not hasattr(self, 'mylog'): + return + prev_size = 0 + while True: + await asyncio.sleep(5) + current_size = self.mylog.tell() + if current_size != prev_size: + try: + with open(self.logfile, "w") as f: + f.write(self._logclean(self.mylog.getvalue().decode(), True)) + prev_size = current_size + except Exception: + pass + + try: + # gather runs until any task completes (or we just let them run until EOF breaks them) + # Ingress breaks on user EOF. Egress breaks on child EOF. + # We want to exit if either happens, so return_exceptions=False, but we need to cancel the others. + tasks = [ + asyncio.create_task(ingress_task()), + asyncio.create_task(egress_task()), + asyncio.create_task(keepalive_task()), + asyncio.create_task(savelog_task()) + ] + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + for p in pending: + p.cancel() + finally: + loop.remove_reader(child_fd) + try: + flags = fcntl.fcntl(child_fd, fcntl.F_GETFL) + fcntl.fcntl(child_fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK) + except Exception: + pass + finally: + local_stream.teardown() + + + @MethodHook + def interact(self, debug=False, logger=None): + ''' + Asynchronous interactive session using Smart Tunnel architecture. + Allows multiplexing I/O and handling SIGWINCH events locally without blocking. + ''' + connect = self._connect(debug=debug, logger=logger) + if connect == True: + try: + self._setup_interact_environment(debug=debug, logger=logger, async_mode=True) + + local_stream = LocalStream() + + def resize_callback(rows, cols): + try: + self.child.setwinsize(rows, cols) + except Exception: + pass + + asyncio.run(self._async_interact_loop(local_stream, resize_callback)) + finally: + self._teardown_interact_environment() else: if logger: logger("error", str(connect)) diff --git a/connpy/grpc_layer/server.py b/connpy/grpc_layer/server.py index fc6b282..2a9b9b9 100644 --- a/connpy/grpc_layer/server.py +++ b/connpy/grpc_layer/server.py @@ -61,10 +61,13 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer): @handle_errors def interact_node(self, request_iterator, context): import sys - import select import os + import asyncio from connpy.core import node from ..services.profile_service import ProfileService + from connpy.tunnels import RemoteStream + import queue + import threading # Fetch first setup packet try: @@ -83,11 +86,11 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer): base_node_id = params.get("base_node") # Valid attributes that a node object accepts valid_attrs = ['host', 'options', 'logs', 'password', 'port', 'protocol', 'user', 'jumphost'] - + fallback_id = f"{unique_id}@remote" if unique_id == "dynamic" and params.get("host"): fallback_id = f"dynamic-{params.get('host')}@remote" - + if base_node_id: # Look up the base node in config and use its full data nodes = self.service.config._getallnodes(base_node_id) @@ -97,14 +100,14 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer): for attr in valid_attrs: if attr in params: device[attr] = params[attr] - + if "tags" in params: device_tags = device.get("tags", {}) if not isinstance(device_tags, dict): device_tags = {} device_tags.update(params["tags"]) device["tags"] = device_tags - + node_name = params.get("name", base_node_id) n = node(node_name, **device, config=self.service.config) else: @@ -138,34 +141,10 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer): if connect != True: yield connpy_pb2.InteractResponse(success=False, error_message=str(connect)) return - + # Signal successful connection to the client yield connpy_pb2.InteractResponse(success=True) - import threading - import queue - - stdin_queue = queue.Queue() - running = True - - def read_requests(): - try: - for req in request_iterator: - if not running: - break - if req.cols > 0 and req.rows > 0: - try: - n.child.setwinsize(req.rows, req.cols) - except Exception: - pass - if req.stdin_data: - stdin_queue.put(req.stdin_data) - except grpc.RpcError: - pass - - t = threading.Thread(target=read_requests, daemon=True) - t.start() - # Set initial window size if provided if first_req.cols > 0 and first_req.rows > 0: try: @@ -173,32 +152,34 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer): except Exception: pass - try: - while n.child.isalive() and running: - r, _, _ = select.select([n.child.child_fd], [], [], 0.05) - if r: - try: - data = os.read(n.child.child_fd, 4096) - if not data: - break - yield connpy_pb2.InteractResponse(stdout_data=data) - except OSError: - break - - while not stdin_queue.empty(): - data = stdin_queue.get_nowait() - try: - os.write(n.child.child_fd, data) - except OSError: - running = False - break - finally: - running = False - try: - n.child.terminate(force=True) - except Exception: - pass + response_queue = queue.Queue() + remote_stream = RemoteStream(request_iterator, response_queue) + def run_async_loop(): + try: + n._setup_interact_environment(debug=debug, logger=None, async_mode=True) + def resize_callback(rows, cols): + try: + n.child.setwinsize(rows, cols) + except Exception: + pass + + asyncio.run(n._async_interact_loop(remote_stream, resize_callback)) + except Exception as e: + pass + finally: + n._teardown_interact_environment() + response_queue.put(None) # Signal EOF + + t_loop = threading.Thread(target=run_async_loop, daemon=True) + t_loop.start() + + while True: + data = response_queue.get() + if data is None: + printer.console.print(f"[debug][DEBUG][/debug] gRPC interact_node session closed for: [bold cyan]{unique_id}[/bold cyan]") + break + yield connpy_pb2.InteractResponse(stdout_data=data) @handle_errors def list_nodes(self, request, context): f = request.filter_str if request.filter_str else None @@ -691,6 +672,8 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer): daemon=True ) ai_thread.start() + except grpc.RpcError: + pass except Exception as e: print(f"Request Listener Error: {e}") finally: diff --git a/connpy/services/node_service.py b/connpy/services/node_service.py index 0ed96e2..2a04c03 100644 --- a/connpy/services/node_service.py +++ b/connpy/services/node_service.py @@ -182,7 +182,7 @@ class NodeService(BaseService): n = node(unique_id, **resolved_data, config=self.config) if sftp: n.protocol = "sftp" - + n.interact(debug=debug, logger=logger) def move_node(self, src_id, dst_id, copy=False): diff --git a/connpy/tunnels.py b/connpy/tunnels.py new file mode 100644 index 0000000..4ff9308 --- /dev/null +++ b/connpy/tunnels.py @@ -0,0 +1,171 @@ +import asyncio +import os +import sys +import termios +import tty +import signal +import struct +import fcntl + +class LocalStream: + """ + Asynchronous stream wrapper for local stdin/stdout. + Handles terminal raw mode, async I/O, and SIGWINCH signals. + """ + def __init__(self): + self.stdin_fd = sys.stdin.fileno() + self.stdout_fd = sys.stdout.fileno() + self.original_tty_settings = None + self.resize_callback = None + self._reader_queue = asyncio.Queue() + self._loop = None + + def setup(self, resize_callback=None): + self._loop = asyncio.get_running_loop() + self.resize_callback = resize_callback + + # Save original terminal settings + try: + self.original_tty_settings = termios.tcgetattr(self.stdin_fd) + tty.setraw(self.stdin_fd) + except termios.error: + # Not a TTY, maybe piped or redirected + pass + + # Set stdin non-blocking + flags = fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) + fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + # Setup read callback + self._loop.add_reader(self.stdin_fd, self._read_ready) + + # Register SIGWINCH + if resize_callback: + try: + self._loop.add_signal_handler(signal.SIGWINCH, self._handle_winch) + except (NotImplementedError, RuntimeError): + # signal handling not supported on some loops (e.g., Windows Proactor) + pass + + def teardown(self): + if self._loop: + try: + self._loop.remove_reader(self.stdin_fd) + except Exception: + pass + if self.resize_callback: + try: + self._loop.remove_signal_handler(signal.SIGWINCH) + except Exception: + pass + + # Restore terminal settings + if self.original_tty_settings is not None: + try: + termios.tcsetattr(self.stdin_fd, termios.TCSADRAIN, self.original_tty_settings) + except termios.error: + pass + + # Restore blocking mode for stdin + try: + flags = fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) + fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK) + except Exception: + pass + + def _read_ready(self): + try: + # Read whatever is available + data = os.read(self.stdin_fd, 4096) + if data: + self._reader_queue.put_nowait(data) + else: + self._reader_queue.put_nowait(b'') # EOF + except BlockingIOError: + pass + except OSError: + self._reader_queue.put_nowait(b'') # EOF on error + + async def read(self) -> bytes: + """Asynchronously read bytes from stdin.""" + return await self._reader_queue.get() + + async def write(self, data: bytes): + """Asynchronously write bytes to stdout.""" + if not data: + return + + try: + os.write(self.stdout_fd, data) + except OSError: + pass + + def _handle_winch(self): + if self.resize_callback: + try: + # Use ioctl to get the current window size + s = struct.pack("HHHH", 0, 0, 0, 0) + a = fcntl.ioctl(self.stdout_fd, termios.TIOCGWINSZ, s) + rows, cols, _, _ = struct.unpack("HHHH", a) + + # We schedule the callback safely inside the asyncio loop + # instead of running it raw in the signal handler + self._loop.call_soon(self.resize_callback, rows, cols) + except Exception: + pass + + +import threading + +class RemoteStream: + """ + Asynchronous stream wrapper for gRPC remote connections. + Bridges the blocking gRPC iterators with the async _async_interact_loop. + """ + def __init__(self, request_iterator, response_queue): + self.request_iterator = request_iterator + self.response_queue = response_queue + self.running = True + self._reader_queue = asyncio.Queue() + self.resize_callback = None + self._loop = None + self.t = None + + def setup(self, resize_callback=None): + self._loop = asyncio.get_running_loop() + self.resize_callback = resize_callback + + def read_requests(): + try: + for req in self.request_iterator: + if not self.running: + break + if req.cols > 0 and req.rows > 0: + if self.resize_callback: + self._loop.call_soon_threadsafe(self.resize_callback, req.rows, req.cols) + if req.stdin_data: + self._loop.call_soon_threadsafe(self._reader_queue.put_nowait, req.stdin_data) + except Exception: + pass + finally: + if self._loop and not self._loop.is_closed(): + try: + self._loop.call_soon_threadsafe(self._reader_queue.put_nowait, b'') + except RuntimeError: + pass + + self.t = threading.Thread(target=read_requests, daemon=True) + self.t.start() + + def teardown(self): + self.running = False + self.response_queue.put(None) # Signal EOF + + async def read(self) -> bytes: + """Asynchronously read bytes from the gRPC iterator queue.""" + return await self._reader_queue.get() + + async def write(self, data: bytes): + """Asynchronously write bytes to the gRPC response queue.""" + if data: + self.response_queue.put(data) diff --git a/docs/connpy/cli/ai_handler.html b/docs/connpy/cli/ai_handler.html index 32cd613..78549bc 100644 --- a/docs/connpy/cli/ai_handler.html +++ b/docs/connpy/cli/ai_handler.html @@ -3,7 +3,7 @@
- +