Files
connpy/connpy/grpc_layer/stubs.py
T

703 lines
27 KiB
Python

import grpc
import queue
import threading
from functools import wraps
from google.protobuf.empty_pb2 import Empty
from . import connpy_pb2, connpy_pb2_grpc, remote_plugin_pb2, remote_plugin_pb2_grpc
from .utils import to_value, from_value, to_struct, from_struct
from ..services.exceptions import ConnpyError
from ..hooks import MethodHook
from .. import printer
def handle_errors(func):
    """Decorate *func* so that ``grpc.RpcError`` surfaces as ``ConnpyError``.

    The CLI handlers only know about ``ConnpyError``; this decorator keeps them
    agnostic of the transport by translating gRPC failures into that type and
    rewriting a few common gRPC detail strings into friendlier messages.

    Both plain functions and generator functions are supported. For a
    generator function the translation is applied while the generator is being
    consumed, because that is when a streaming RPC actually raises.

    Args:
        func: The callable to wrap. When it is called as a method, the first
            positional argument is inspected for a ``remote_host`` attribute
            that is used in the error message.

    Returns:
        The wrapped callable (metadata preserved via ``functools.wraps``).

    Raises:
        ConnpyError: In place of any ``grpc.RpcError`` raised by *func*.
    """
    import inspect

    def _translate(exc, call_args):
        """Build the ConnpyError that replaces the gRPC error *exc*."""
        # Not every RpcError exposes details(), and details() may return None;
        # fall back to the exception text so the substring checks stay safe.
        details_fn = getattr(exc, "details", None)
        details = details_fn() if callable(details_fn) else None
        if not details:
            details = str(exc)

        # Identify the host if available on the instance
        instance = call_args[0] if call_args else None
        host = getattr(instance, "remote_host", "remote host")

        # Make common gRPC errors more readable
        if "failed to connect to all addresses" in details:
            simplified = f"Failed to connect to remote host at {host} (Connection refused)"
        elif "Method not found" in details:
            simplified = f"Remote server at {host} is using an incompatible version"
        elif "Deadline Exceeded" in details:
            simplified = f"Request to {host} timed out"
        else:
            simplified = details
        return ConnpyError(simplified)

    if inspect.isgeneratorfunction(func):
        # Calling a generator function returns immediately; errors only appear
        # during iteration, so the try/except has to live inside a generator.
        @wraps(func)
        def generator_wrapper(*args, **kwargs):
            try:
                yield from func(*args, **kwargs)
            except grpc.RpcError as e:
                raise _translate(e, args) from e

        return generator_wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except grpc.RpcError as e:
            raise _translate(e, args) from e

    return wrapper
class NodeStub:
    """Client-side proxy for the remote ``NodeService`` gRPC service.

    Besides forwarding node/folder CRUD calls, it drives interactive terminal
    sessions over the bidirectional ``interact_node`` stream and keeps the
    local cache files in sync through ``config._generate_nodes_cache``.
    """

    # Protocols for which the connection banner omits the ":port" suffix.
    _PORTLESS_PROTOCOLS = ("ssm", "kubectl", "docker")

    def __init__(self, channel, remote_host, config=None):
        """Bind the stub to *channel*.

        Args:
            channel: gRPC channel used to build the service stub.
            remote_host: Host label used in error messages by ``handle_errors``.
            config: Optional local configuration object; when present its
                ``_generate_nodes_cache`` method is used to write cache files.
        """
        self.stub = connpy_pb2_grpc.NodeServiceStub(channel)
        self.remote_host = remote_host
        self.config = config

    @classmethod
    def _connection_banner(cls, name, host, port, protocol):
        """Return the "Connected to ..." message shown once a session starts."""
        # A missing port may be stored as None; treat it like an empty value
        # instead of rendering ":None".
        port = "" if port is None else str(port)
        port_str = f":{port}" if port and protocol not in cls._PORTLESS_PROTOCOLS else ""
        return f"Connected to {name} at {host}{port_str} via: {protocol}"

    def _interactive_session(self, first_request_fields, conn_msg):
        """Run a raw-mode terminal session over the ``interact_node`` stream.

        Args:
            first_request_fields: Keyword fields for the initial
                ``InteractRequest`` (terminal size is added here).
            conn_msg: Message printed when the server reports success.

        The terminal is switched to raw mode for the duration of the session
        and always restored on exit. gRPC errors are not handled here; the
        decorated public callers translate them.
        """
        import os
        import select
        import sys
        import termios
        import tty

        # Set when the session is over so the stdin reader stops instead of
        # staying blocked on the terminal.
        stop_reading = threading.Event()

        def request_generator():
            cols, rows = 80, 24
            try:
                size = os.get_terminal_size()
                cols, rows = size.columns, size.lines
            except OSError:
                pass
            yield connpy_pb2.InteractRequest(cols=cols, rows=rows, **first_request_fields)

            stdin_fd = sys.stdin.fileno()
            while not stop_reading.is_set():
                # Poll with a timeout so the stop flag is re-checked regularly.
                ready, _, _ = select.select([stdin_fd], [], [], 0.1)
                if not ready or stop_reading.is_set():
                    continue
                try:
                    data = os.read(stdin_fd, 1024)
                except OSError:
                    break
                if not data:
                    break
                yield connpy_pb2.InteractRequest(stdin_data=data)

        old_tty = termios.tcgetattr(sys.stdin)
        try:
            tty.setraw(sys.stdin.fileno())
            response_iterator = self.stub.interact_node(request_generator())

            # First response is connection status
            try:
                first_res = next(response_iterator)
            except StopIteration:
                return

            # Leave raw mode while printing so the message renders normally.
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
            if not first_res.success:
                # Connection failed on server
                printer.error(f"Connection failed: {first_res.error_message}")
                return

            # Connection established on server, show success message
            printer.success(conn_msg)
            tty.setraw(sys.stdin.fileno())

            for res in response_iterator:
                if res.stdout_data:
                    os.write(sys.stdout.fileno(), res.stdout_data)
        finally:
            stop_reading.set()
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)

    @handle_errors
    def connect_node(self, unique_id, sftp=False, debug=False, logger=None):
        """Open an interactive session to the stored node *unique_id*.

        Args:
            unique_id: Identifier of the node on the server.
            sftp: Request an SFTP session instead of the node's own protocol.
            debug: Forwarded to the server in the initial request.
            logger: Accepted for interface compatibility; not used here.
        """
        # Fetch node details for the connection message
        try:
            node_details = self.get_node_details(unique_id)
            protocol = "sftp" if sftp else node_details.get("protocol", "ssh")
            conn_msg = self._connection_banner(
                unique_id,
                node_details.get("host", "unknown"),
                node_details.get("port", ""),
                protocol,
            )
        except Exception:
            # The banner is cosmetic; fall back to a minimal message.
            conn_msg = f"Connected to {unique_id}"

        self._interactive_session(
            {"id": unique_id, "sftp": sftp, "debug": debug}, conn_msg
        )

    @handle_errors
    def connect_dynamic(self, connection_params, debug=False):
        """Open an interactive session described by *connection_params*.

        Args:
            connection_params: JSON-serialisable connection description; it is
                sent to the server as ``connection_params_json``.
            debug: Forwarded to the server in the initial request.
        """
        import json

        params_json = json.dumps(connection_params)

        # Prepare connection message
        try:
            node_name = connection_params.get("name", "dynamic@remote")
            conn_msg = self._connection_banner(
                node_name,
                connection_params.get("host", "dynamic"),
                connection_params.get("port", ""),
                connection_params.get("protocol", "ssh"),
            )
        except Exception:
            node_name = (
                connection_params.get("name", "dynamic@remote")
                if isinstance(connection_params, dict)
                else "dynamic@remote"
            )
            conn_msg = f"Connected to {node_name}"

        self._interactive_session(
            {"id": "dynamic", "debug": debug, "connection_params_json": params_json},
            conn_msg,
        )

    @MethodHook
    @handle_errors
    def list_nodes(self, filter_str=None, format_str=None):
        """Return the server's node list (``[]`` when the reply is empty)."""
        req = connpy_pb2.FilterRequest(filter_str=filter_str or "", format_str=format_str or "")
        return from_value(self.stub.list_nodes(req).data) or []

    @MethodHook
    @handle_errors
    def list_folders(self, filter_str=None):
        """Return the server's folder list (``[]`` when the reply is empty)."""
        req = connpy_pb2.FilterRequest(filter_str=filter_str or "")
        return from_value(self.stub.list_folders(req).data) or []

    @handle_errors
    def get_node_details(self, unique_id):
        """Return the decoded details struct for node *unique_id*."""
        return from_struct(self.stub.get_node_details(connpy_pb2.IdRequest(id=unique_id)).data)

    @handle_errors
    def explode_unique(self, unique_id):
        """Return the server's ``explode_unique`` result for *unique_id*."""
        return from_value(self.stub.explode_unique(connpy_pb2.IdRequest(id=unique_id)).data)

    @handle_errors
    def validate_parent_folder(self, unique_id):
        """Ask the server to validate the parent folder of *unique_id*."""
        self.stub.validate_parent_folder(connpy_pb2.IdRequest(id=unique_id))

    @handle_errors
    def generate_cache(self, nodes=None, folders=None, profiles=None):
        """Regenerate the server cache and refresh the local cache files.

        When no data is supplied, nodes and folders are fetched from the
        server first. The local step is skipped when no ``config`` was given.
        """
        # 1. Update remote cache on server
        self.stub.generate_cache(Empty())

        # 2. Update local fzf/text cache files
        if self.config is None:
            # Nothing local to write to.
            return

        # If no data provided, we fetch it all from remote to sync local files
        if nodes is None and folders is None and profiles is None:
            nodes = self.list_nodes()
            folders = self.list_folders()
            # Profiles are not fetched here: this stub has no ProfileStub.

        if nodes is not None or folders is not None or profiles is not None:
            self.config._generate_nodes_cache(nodes=nodes, folders=folders, profiles=profiles)

    def _trigger_local_cache_sync(self):
        """Helper to fetch remote data and update local fzf cache files after a change."""
        try:
            nodes = self.list_nodes()
            folders = self.list_folders()
            self.generate_cache(nodes=nodes, folders=folders)
        except Exception:
            # Failure to sync cache shouldn't break the main operation's success feedback
            pass

    @handle_errors
    def add_node(self, unique_id, data, is_folder=False):
        """Create a node (or folder) on the server, then resync the local cache."""
        req = connpy_pb2.NodeRequest(id=unique_id, data=to_struct(data), is_folder=is_folder)
        self.stub.add_node(req)
        self._trigger_local_cache_sync()

    @handle_errors
    def update_node(self, unique_id, data):
        """Update node *unique_id* on the server, then resync the local cache."""
        req = connpy_pb2.NodeRequest(id=unique_id, data=to_struct(data), is_folder=False)
        self.stub.update_node(req)
        self._trigger_local_cache_sync()

    @handle_errors
    def delete_node(self, unique_id, is_folder=False):
        """Delete a node (or folder) on the server, then resync the local cache."""
        req = connpy_pb2.DeleteRequest(id=unique_id, is_folder=is_folder)
        self.stub.delete_node(req)
        self._trigger_local_cache_sync()

    @handle_errors
    def move_node(self, src_id, dst_id, copy=False):
        """Move (or copy) *src_id* to *dst_id*, then resync the local cache."""
        req = connpy_pb2.MoveRequest(src_id=src_id, dst_id=dst_id, copy=copy)
        self.stub.move_node(req)
        self._trigger_local_cache_sync()

    @handle_errors
    def bulk_add(self, ids, hosts, common_data):
        """Send a bulk-add request, then resync the local cache."""
        req = connpy_pb2.BulkRequest(ids=ids, hosts=hosts, common_data=to_struct(common_data))
        self.stub.bulk_add(req)
        self._trigger_local_cache_sync()

    @handle_errors
    def set_reserved_names(self, names):
        """Send the reserved-names list to the server, then resync the local cache."""
        self.stub.set_reserved_names(connpy_pb2.ListRequest(items=names))
        self._trigger_local_cache_sync()

    @handle_errors
    def full_replace(self, connections, profiles):
        """Replace connections and profiles on the server, then resync the local cache."""
        req = connpy_pb2.FullReplaceRequest(
            connections=to_struct(connections),
            profiles=to_struct(profiles),
        )
        self.stub.full_replace(req)
        self._trigger_local_cache_sync()

    @handle_errors
    def get_inventory(self):
        """Return ``{"connections": ..., "profiles": ...}`` decoded from the server reply."""
        resp = self.stub.get_inventory(Empty())
        return {
            "connections": from_struct(resp.connections),
            "profiles": from_struct(resp.profiles),
        }
class ProfileStub:
    """Client-side proxy for the remote ``ProfileService`` gRPC service.

    Mutating calls also ask the optional ``node_stub`` to refresh the local
    cache files.
    """

    def __init__(self, channel, remote_host, node_stub=None):
        """Bind the stub to *channel*; *node_stub* is optional and used for cache sync."""
        self.stub = connpy_pb2_grpc.ProfileServiceStub(channel)
        self.remote_host = remote_host
        self.node_stub = node_stub

    def _sync_local_cache(self):
        """Refresh the local cache through the node stub, when one was provided."""
        if self.node_stub:
            self.node_stub._trigger_local_cache_sync()

    @handle_errors
    def list_profiles(self, filter_str=None):
        """Return the profile list from the server (``[]`` when the reply is empty)."""
        request = connpy_pb2.FilterRequest(filter_str=filter_str or "")
        reply = self.stub.list_profiles(request)
        return from_value(reply.data) or []

    @handle_errors
    def get_profile(self, name, resolve=True):
        """Return the decoded profile *name*; *resolve* is forwarded to the server."""
        request = connpy_pb2.ProfileRequest(name=name, resolve=resolve)
        reply = self.stub.get_profile(request)
        return from_struct(reply.data)

    @handle_errors
    def add_profile(self, name, data):
        """Create profile *name* on the server, then refresh the local cache."""
        request = connpy_pb2.NodeRequest(id=name, data=to_struct(data))
        self.stub.add_profile(request)
        self._sync_local_cache()

    @handle_errors
    def resolve_node_data(self, node_data):
        """Send *node_data* to the server's ``resolve_node_data`` and return the decoded reply."""
        request = connpy_pb2.StructRequest(data=to_struct(node_data))
        reply = self.stub.resolve_node_data(request)
        return from_struct(reply.data)

    @handle_errors
    def delete_profile(self, name):
        """Delete profile *name* on the server, then refresh the local cache."""
        request = connpy_pb2.IdRequest(id=name)
        self.stub.delete_profile(request)
        self._sync_local_cache()

    @handle_errors
    def update_profile(self, name, data):
        """Update profile *name* on the server, then refresh the local cache."""
        request = connpy_pb2.NodeRequest(id=name, data=to_struct(data))
        self.stub.update_profile(request)
        self._sync_local_cache()
class ConfigStub:
    """Client-side proxy for the remote ``ConfigService`` gRPC service."""

    def __init__(self, channel, remote_host):
        """Bind the stub to *channel*; *remote_host* is used in error messages."""
        self.stub = connpy_pb2_grpc.ConfigServiceStub(channel)
        self.remote_host = remote_host

    @handle_errors
    def get_settings(self):
        """Return the server settings decoded from the reply struct."""
        reply = self.stub.get_settings(Empty())
        return from_struct(reply.data)

    @handle_errors
    def update_setting(self, key, value):
        """Send a single ``key``/``value`` setting update to the server."""
        request = connpy_pb2.UpdateRequest(key=key, value=to_value(value))
        self.stub.update_setting(request)

    @handle_errors
    def get_default_dir(self):
        """Return the ``value`` field of the server's ``get_default_dir`` reply."""
        reply = self.stub.get_default_dir(Empty())
        return reply.value

    @handle_errors
    def set_config_folder(self, folder):
        """Send *folder* to the server's ``set_config_folder`` call."""
        request = connpy_pb2.StringRequest(value=folder)
        self.stub.set_config_folder(request)

    @handle_errors
    def encrypt_password(self, password):
        """Send *password* to the server and return the ``value`` of its reply."""
        request = connpy_pb2.StringRequest(value=password)
        reply = self.stub.encrypt_password(request)
        return reply.value
class PluginStub:
    """Client-side proxy for the plugin services.

    Management calls go through ``PluginService``; source retrieval and
    invocation go through ``RemotePluginService`` on the same channel.
    """

    def __init__(self, channel, remote_host):
        """Bind both plugin service stubs to *channel*."""
        self.stub = connpy_pb2_grpc.PluginServiceStub(channel)
        self.remote_stub = remote_plugin_pb2_grpc.RemotePluginServiceStub(channel)
        self.remote_host = remote_host

    @handle_errors
    def list_plugins(self):
        """Return the decoded plugin listing from the server."""
        return from_value(self.stub.list_plugins(Empty()).data)

    @handle_errors
    def add_plugin(self, name, source_file, update=False):
        """Upload the local plugin file *source_file* under *name*.

        The file content is sent inline in the ``source_file`` field, prefixed
        with a ``---CONTENT---`` marker line.

        Raises:
            OSError: If *source_file* cannot be read.
        """
        # Read the local file content to send it to the server.
        # Decode explicitly as UTF-8 rather than the platform locale so the
        # uploaded text does not depend on the client machine's settings.
        with open(source_file, "r", encoding="utf-8") as f:
            content = f.read()

        # Use source_file as a marker for "content-inside"
        marker_content = f"---CONTENT---\n{content}"
        req = connpy_pb2.PluginRequest(name=name, source_file=marker_content, update=update)
        self.stub.add_plugin(req)

    @handle_errors
    def delete_plugin(self, name):
        """Delete plugin *name* on the server."""
        self.stub.delete_plugin(connpy_pb2.IdRequest(id=name))

    @handle_errors
    def enable_plugin(self, name):
        """Enable plugin *name* on the server."""
        self.stub.enable_plugin(connpy_pb2.IdRequest(id=name))

    @handle_errors
    def disable_plugin(self, name):
        """Disable plugin *name* on the server."""
        self.stub.disable_plugin(connpy_pb2.IdRequest(id=name))

    @handle_errors
    def get_plugin_source(self, name):
        """Return the ``value`` field of the remote ``get_plugin_source`` reply."""
        resp = self.remote_stub.get_plugin_source(remote_plugin_pb2.IdRequest(id=name))
        return resp.value

    @handle_errors
    def invoke_plugin(self, name, args_namespace):
        """Invoke plugin *name* remotely and yield its output text chunks.

        Only attributes of *args_namespace* holding str, int, float, bool,
        list or None are forwarded (as JSON). If the namespace has a ``func``
        with a ``__name__``, it is sent as ``__func_name__``.

        This is a generator: the RPC is issued and errors surface only while
        the caller iterates.
        NOTE(review): whether those errors are translated depends on
        ``handle_errors`` supporting generator functions - confirm.
        """
        import json

        args_dict = {
            k: v
            for k, v in vars(args_namespace).items()
            if isinstance(v, (str, int, float, bool, list, type(None)))
        }
        if hasattr(args_namespace, "func") and hasattr(args_namespace.func, "__name__"):
            args_dict["__func_name__"] = args_namespace.func.__name__

        req = remote_plugin_pb2.PluginInvokeRequest(name=name, args_json=json.dumps(args_dict))
        for chunk in self.remote_stub.invoke_plugin(req):
            yield chunk.text
class ExecutionStub:
    """Client-side proxy for the remote ``ExecutionService`` gRPC service."""

    def __init__(self, channel, remote_host):
        """Bind the stub to *channel*; *remote_host* is used in error messages."""
        self.stub = connpy_pb2_grpc.ExecutionServiceStub(channel)
        self.remote_host = remote_host

    @staticmethod
    def _as_list(value):
        """Return *value* as a list, wrapping a bare string instead of splitting it."""
        return [value] if isinstance(value, str) else list(value)

    @handle_errors
    def run_commands(self, nodes_filter, commands, variables=None, parallel=10, timeout=10, folder=None, prompt=None, **kwargs):
        """Run *commands* on the selected nodes and collect per-node results.

        Args:
            nodes_filter: A node name or an iterable of node names.
            commands: A command or an iterable of commands.
            variables: Optional mapping copied into the request's ``vars``.
            parallel, timeout, folder, prompt: Forwarded to the server.
            **kwargs: ``name`` (request name) and ``on_node_complete``, a
                callback invoked as ``cb(unique_id, output, status)`` for each
                streamed response.

        Returns:
            dict mapping ``unique_id`` to ``{"output": ..., "status": ...}``.
        """
        req = connpy_pb2.RunRequest(
            nodes=self._as_list(nodes_filter),
            commands=self._as_list(commands),
            folder=folder or "",
            prompt=prompt or "",
            parallel=parallel,
            timeout=timeout,
            # `or ""` also covers an explicit name=None, which protobuf rejects.
            name=kwargs.get("name") or "",
        )
        if variables is not None:
            req.vars.CopyFrom(to_struct(variables))

        final_results = {}
        on_complete = kwargs.get("on_node_complete")
        for response in self.stub.run_commands(req):
            if on_complete:
                on_complete(response.unique_id, response.output, response.status)
            final_results[response.unique_id] = {
                "output": response.output,
                "status": response.status,
            }
        return final_results

    @handle_errors
    def test_commands(self, nodes_filter, commands, expected, variables=None, parallel=10, timeout=10, prompt=None, **kwargs):
        """Run *commands* on the selected nodes and check output against *expected*.

        Args:
            nodes_filter: A node name or an iterable of node names.
            commands: A command or an iterable of commands.
            expected: An expected string or an iterable of expected strings.
            variables: Optional mapping copied into the request's ``vars``.
            parallel, timeout, prompt: Forwarded to the server.
            **kwargs: ``folder``, ``name`` and ``on_node_complete``, a callback
                invoked as ``cb(unique_id, output, status, result_dict)``.

        Returns:
            dict mapping ``unique_id`` to the decoded ``test_result`` struct
            (``{}`` when the response carries none).
        """
        req = connpy_pb2.TestRequest(
            nodes=self._as_list(nodes_filter),
            commands=self._as_list(commands),
            expected=self._as_list(expected),
            # Match run_commands: tolerate folder=None / name=None.
            folder=kwargs.get("folder") or "",
            prompt=prompt or "",
            parallel=parallel,
            timeout=timeout,
            name=kwargs.get("name") or "",
        )
        if variables is not None:
            req.vars.CopyFrom(to_struct(variables))

        final_results = {}
        on_complete = kwargs.get("on_node_complete")
        for response in self.stub.test_commands(req):
            result_dict = from_struct(response.test_result) if response.HasField("test_result") else {}
            if on_complete:
                on_complete(response.unique_id, response.output, response.status, result_dict)
            final_results[response.unique_id] = result_dict
        return final_results

    @handle_errors
    def run_cli_script(self, nodes_filter, script_path, parallel=10):
        """Send a script request (``param1`` = nodes filter, ``param2`` = script path).

        Returns the decoded reply struct.
        """
        req = connpy_pb2.ScriptRequest(param1=nodes_filter, param2=script_path, parallel=parallel)
        return from_struct(self.stub.run_cli_script(req).data)

    @handle_errors
    def run_yaml_playbook(self, playbook_path, parallel=10):
        """Send a playbook request (``param1`` = playbook path) and return the decoded reply."""
        req = connpy_pb2.ScriptRequest(param1=playbook_path, parallel=parallel)
        return from_struct(self.stub.run_yaml_playbook(req).data)
class ImportExportStub:
    """Client-side proxy for the remote ``ImportExportService`` gRPC service."""

    def __init__(self, channel, remote_host):
        """Bind the stub to *channel*; *remote_host* is used in error messages."""
        self.stub = connpy_pb2_grpc.ImportExportServiceStub(channel)
        self.remote_host = remote_host

    @handle_errors
    def export_to_file(self, file_path, folders=None):
        """Send an export request for *file_path*, optionally limited to *folders*."""
        req = connpy_pb2.ExportRequest(file_path=file_path, folders=folders or [])
        self.stub.export_to_file(req)

    @handle_errors
    def import_from_file(self, file_path):
        """Read the local file *file_path* and send its content to the server.

        The content is sent inline, prefixed with a ``---YAML---`` marker line
        that tells the server it is content rather than a path.

        Raises:
            OSError: If *file_path* cannot be read.
        """
        # Decode explicitly as UTF-8 rather than the platform locale so the
        # uploaded text does not depend on the client machine's settings.
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()

        # Marker to tell the server this is content, not a path
        marker_content = f"---YAML---\n{content}"
        self.stub.import_from_file(connpy_pb2.StringRequest(value=marker_content))

    @handle_errors
    def set_reserved_names(self, names):
        """Send the reserved-names list to the server."""
        self.stub.set_reserved_names(connpy_pb2.ListRequest(items=names))
class AIStub:
    """Client-side proxy for the remote ``AIService`` gRPC service."""

    def __init__(self, channel, remote_host):
        """Bind the stub to *channel*; *remote_host* is used in error messages."""
        self.stub = connpy_pb2_grpc.AIServiceStub(channel)
        self.remote_host = remote_host

    @handle_errors
    def ask(self, input_text, dryrun=False, chat_history=None, session_id=None, debug=False, status=None, **overrides):
        """Run an AI request over the bidirectional ``ask`` stream and render it.

        Streamed text chunks are shown in a live panel (unless *debug* is on),
        status updates go to *status*, and confirmation prompts are answered
        interactively and sent back on the request stream. Ctrl-C sends an
        interrupt request to the server and keeps reading the remaining
        responses.

        Args:
            input_text: The user's prompt.
            dryrun: Forwarded to the server.
            chat_history: Optional history, sent as a protobuf value.
            session_id: Optional session identifier.
            debug: Forwarded to the server; also enables printing of debug
                messages and disables the live panel.
            status: Optional status object; ``update``/``start``/``stop`` are
                called on it when provided.
            **overrides: ``engineer_model``, ``engineer_api_key``,
                ``architect_model``, ``architect_api_key`` and ``trust``.

        Returns:
            The decoded final result, or the default
            ``{"response": "", "chat_history": []}`` when no final message
            arrived. ``"streamed": True`` is added when any text was streamed.
        """
        from rich.prompt import Prompt
        from rich.text import Text
        from rich.live import Live
        from rich.panel import Panel
        from rich.markdown import Markdown

        req_queue = queue.Queue()

        initial_req = connpy_pb2.AskRequest(
            input_text=input_text,
            dryrun=dryrun,
            session_id=session_id or "",
            debug=debug,
            engineer_model=overrides.get("engineer_model", ""),
            engineer_api_key=overrides.get("engineer_api_key", ""),
            architect_model=overrides.get("architect_model", ""),
            architect_api_key=overrides.get("architect_api_key", ""),
            trust=overrides.get("trust", False),
        )
        if chat_history is not None:
            initial_req.chat_history.CopyFrom(to_value(chat_history))

        req_queue.put(initial_req)

        def request_generator():
            # Feed queued requests to gRPC until the None sentinel arrives.
            while True:
                req = req_queue.get()
                if req is None:
                    break
                yield req

        responses = self.stub.ask(request_generator())

        full_content = ""
        live_display = None
        final_result = {"response": "", "chat_history": []}

        # Background thread to pull responses from gRPC into a local queue
        # This prevents KeyboardInterrupt from corrupting the gRPC iterator state
        response_queue = queue.Queue()

        def pull_responses():
            try:
                for response in responses:
                    response_queue.put(("data", response))
            except Exception as e:
                response_queue.put(("error", e))
            finally:
                response_queue.put((None, None))

        threading.Thread(target=pull_responses, daemon=True).start()

        try:
            while True:
                try:
                    # BLOCKING GET from local queue (interruptible by signal)
                    msg_type, response = response_queue.get()
                except KeyboardInterrupt:
                    # Signal interruption to the server
                    if status:
                        status.update("[error]Interrupted! Closing pending tasks...")
                    # Send the interrupt signal to the server
                    req_queue.put(connpy_pb2.AskRequest(interrupt=True))
                    # CONTINUE the loop to receive remaining data and summary from the queue
                    continue

                if msg_type is None:  # Sentinel
                    break

                if msg_type == "error":
                    # Re-raise or handle gRPC error from background thread
                    if isinstance(response, grpc.RpcError):
                        raise response
                    printer.warning(f"Stream interrupted: {response}")
                    break

                if response.status_update:
                    if response.requires_confirmation:
                        # Pause live widgets so the prompt is not overwritten.
                        if status:
                            status.stop()
                        if live_display:
                            live_display.stop()

                        # Show prompt and wait for answer
                        prompt_text = Text.from_ansi(response.status_update)
                        ans = Prompt.ask(prompt_text)

                        if status:
                            status.update("[ai_status]Agent: Resuming...")
                            status.start()
                        if live_display:
                            live_display.start()
                        req_queue.put(connpy_pb2.AskRequest(confirmation_answer=ans))
                        continue

                    if status:
                        status.update(response.status_update)
                    continue

                if response.debug_message:
                    if debug:
                        printer.console.print(Text.from_ansi(response.debug_message))
                    continue

                if response.important_message:
                    printer.console.print(Text.from_ansi(response.important_message))
                    continue

                if not response.is_final:
                    full_content += response.text_chunk
                    if not live_display and not debug:
                        if status:
                            status.stop()
                        live_display = Live(
                            Panel(Markdown(full_content), title="AI Assistant", expand=False),
                            console=printer.console,
                            refresh_per_second=8,
                            transient=False,
                        )
                        live_display.start()
                    elif live_display:
                        live_display.update(Panel(Markdown(full_content), title="AI Assistant", expand=False))
                    continue

                if response.is_final:
                    final_result = from_struct(response.full_result)
                    responder = final_result.get("responder", "engineer")
                    alias = "architect" if responder == "architect" else "engineer"
                    role_label = "Network Architect" if responder == "architect" else "Network Engineer"
                    title = f"[bold {alias}]{role_label}[/bold {alias}]"

                    if live_display:
                        live_display.update(Panel(Markdown(full_content), title=title, border_style=alias, expand=False))
                        live_display.stop()
                    elif full_content:
                        printer.console.print(Panel(Markdown(full_content), title=title, border_style=alias, expand=False))
                    break
        except Exception as e:
            # Check if it was a gRPC error that we should let handle_errors catch
            if isinstance(e, grpc.RpcError):
                raise
            printer.warning(f"Stream interrupted: {e}")
        finally:
            # Unblock the request generator so the request stream can end.
            req_queue.put(None)
            # Make sure the live panel is stopped on every exit path (stream
            # ended without a final message, error, interrupt during a prompt).
            if live_display is not None:
                live_display.stop()

        if full_content:
            final_result["streamed"] = True

        return final_result

    @handle_errors
    def confirm(self, input_text, console=None):
        """Send *input_text* to the server's ``confirm`` call and return its ``value``.

        *console* is accepted for interface compatibility and not used here.
        """
        return self.stub.confirm(connpy_pb2.StringRequest(value=input_text)).value

    @handle_errors
    def list_sessions(self):
        """Return the decoded session listing from the server."""
        return from_value(self.stub.list_sessions(Empty()).data)

    @handle_errors
    def delete_session(self, session_id):
        """Delete session *session_id* on the server."""
        self.stub.delete_session(connpy_pb2.StringRequest(value=session_id))

    @handle_errors
    def configure_provider(self, provider, model=None, api_key=None):
        """Send provider configuration; missing *model*/*api_key* are sent as ``""``."""
        req = connpy_pb2.ProviderRequest(provider=provider, model=model or "", api_key=api_key or "")
        self.stub.configure_provider(req)

    @handle_errors
    def load_session_data(self, session_id):
        """Return the decoded data struct for session *session_id*."""
        return from_struct(self.stub.load_session_data(connpy_pb2.StringRequest(value=session_id)).data)
class SystemStub:
    """Client-side proxy for the remote ``SystemService`` gRPC service."""

    def __init__(self, channel, remote_host):
        """Bind the stub to *channel*; *remote_host* is used in error messages."""
        self.stub = connpy_pb2_grpc.SystemServiceStub(channel)
        self.remote_host = remote_host

    @staticmethod
    def _port_request(port):
        """Build the ``IntRequest`` for *port*, using 8048 when it is falsy."""
        return connpy_pb2.IntRequest(value=port or 8048)

    @handle_errors
    def start_api(self, port=None):
        """Call the server's ``start_api`` with *port* (8048 when not given)."""
        request = self._port_request(port)
        self.stub.start_api(request)

    @handle_errors
    def debug_api(self, port=None):
        """Call the server's ``debug_api`` with *port* (8048 when not given)."""
        request = self._port_request(port)
        self.stub.debug_api(request)

    @handle_errors
    def stop_api(self):
        """Call the server's ``stop_api``."""
        self.stub.stop_api(Empty())

    @handle_errors
    def restart_api(self, port=None):
        """Call the server's ``restart_api`` with *port* (8048 when not given)."""
        request = self._port_request(port)
        self.stub.restart_api(request)

    @handle_errors
    def get_api_status(self):
        """Return the ``value`` field of the server's ``get_api_status`` reply."""
        reply = self.stub.get_api_status(Empty())
        return reply.value