feat(v6): integrate MCP support, async copilot streaming, and enhanced terminal UI logic
This commit is contained in:
@@ -50,6 +50,7 @@ coverage.xml
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
scratch/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
@@ -162,3 +163,4 @@ repo_consolidado_limpio.md
|
||||
connpy_roadmap.md
|
||||
MULTI_USER_PLAN.md
|
||||
COPILOT_PLAN.md
|
||||
ARCHITECTURAL_DEBT_REFACTOR.md
|
||||
|
||||
+144
-177
@@ -7,6 +7,7 @@ import threading
|
||||
import asyncio
|
||||
from textwrap import dedent
|
||||
from .core import nodes
|
||||
from .mcp_client import MCPClientManager
|
||||
|
||||
_litellm_initialized = False
|
||||
|
||||
@@ -144,6 +145,9 @@ class ai:
|
||||
self.engineer_prompt_extensions = [] # Extra text for engineer prompt
|
||||
self.architect_prompt_extensions = [] # Extra text for architect prompt
|
||||
|
||||
# MCP Manager
|
||||
self.mcp_manager = MCPClientManager(self.config)
|
||||
|
||||
# Long-term memory
|
||||
self.memory_path = os.path.join(self.config.defaultdir, "ai_memory.md")
|
||||
self.long_term_memory = ""
|
||||
@@ -677,7 +681,7 @@ class ai:
|
||||
self.console.print("[pass]✓ Trust Mode Enabled. All future commands in this session will execute without confirmation.[/pass]")
|
||||
elif user_resp_lower in ['y', 'yes']:
|
||||
self.console.print("[pass]✓ Executing...[/pass]")
|
||||
elif user_resp_lower in ['n', 'no', '']:
|
||||
elif user_resp_lower in ['n', 'no', '', 'cancel']:
|
||||
self.console.print("[fail]✗ Execution rejected by user.[/fail]")
|
||||
return "Error: User rejected execution."
|
||||
else:
|
||||
@@ -773,6 +777,10 @@ class ai:
|
||||
cmd_str = cmds[0] if cmds else ""
|
||||
status.update(f"[ai_status]Engineer: [CMD] {cmd_str}")
|
||||
elif fn == "get_node_info": status.update(f"[ai_status]Engineer: [INSPECT] {args.get('node_name','')}")
|
||||
elif fn.startswith("mcp_"):
|
||||
server = fn.split("__")[0].replace("mcp_", "")
|
||||
tool = fn.split("__")[1] if "__" in fn else fn
|
||||
status.update(f"[ai_status]Engineer: [MCP:{server}] {tool}")
|
||||
elif fn in self.tool_status_formatters: status.update(self.tool_status_formatters[fn](args))
|
||||
|
||||
if debug:
|
||||
@@ -781,6 +789,8 @@ class ai:
|
||||
if fn == "list_nodes": obs = self.list_nodes_tool(**args)
|
||||
elif fn == "run_commands": obs = self.run_commands_tool(**args, status=status)
|
||||
elif fn == "get_node_info": obs = self.get_node_info_tool(**args)
|
||||
elif fn.startswith("mcp_"):
|
||||
obs = run_ai_async(self.mcp_manager.call_tool(fn, args)).result(timeout=60)
|
||||
elif fn in self.external_tool_handlers: obs = self.external_tool_handlers[fn](self, **args)
|
||||
else: obs = f"Error: Unknown tool '{fn}'."
|
||||
|
||||
@@ -801,14 +811,22 @@ class ai:
|
||||
except Exception as e:
|
||||
return f"Engineer failed: {str(e)}", usage
|
||||
|
||||
def _get_engineer_tools(self):
|
||||
def _get_engineer_tools(self, os_filter: str = None):
|
||||
"""Define tools available to the Engineer."""
|
||||
base_tools = [
|
||||
{"type": "function", "function": {"name": "list_nodes", "description": "Lists available nodes in the inventory.", "parameters": {"type": "object", "properties": {"filter_pattern": {"type": "string", "description": "Regex to filter nodes (e.g. '.*', 'border.*')."}}}}},
|
||||
{"type": "function", "function": {"name": "run_commands", "description": "Runs one or more commands on matched nodes. MANDATORY: You MUST call 'list_nodes' first to verify the target list.", "parameters": {"type": "object", "properties": {"nodes_filter": {"type": "string", "description": "Exact node name or verified filter pattern."}, "commands": {"type": "array", "items": {"type": "string"}, "description": "List of commands (e.g. ['show ip route', 'show int desc'])."}}, "required": ["nodes_filter", "commands"]}}},
|
||||
{"type": "function", "function": {"name": "get_node_info", "description": "Gets full metadata for a specific node.", "parameters": {"type": "object", "properties": {"node_name": {"type": "string"}}, "required": ["node_name"]}}}
|
||||
{"type": "function", "function": {"name": "list_nodes", "description": "[Universal Platform] Lists available nodes in the inventory.", "parameters": {"type": "object", "properties": {"filter_pattern": {"type": "string", "description": "Regex to filter nodes (e.g. '.*', 'border.*')."}}}}},
|
||||
{"type": "function", "function": {"name": "run_commands", "description": "[Universal Platform] Runs one or more commands on matched nodes. MANDATORY: You MUST call 'list_nodes' first to verify the target list.", "parameters": {"type": "object", "properties": {"nodes_filter": {"type": "string", "description": "Exact node name or verified filter pattern."}, "commands": {"type": "array", "items": {"type": "string"}, "description": "List of commands (e.g. ['show ip route', 'show int desc'])."}}, "required": ["nodes_filter", "commands"]}}},
|
||||
{"type": "function", "function": {"name": "get_node_info", "description": "[Universal Platform] Gets full metadata for a specific node.", "parameters": {"type": "object", "properties": {"node_name": {"type": "string"}}, "required": ["node_name"]}}}
|
||||
]
|
||||
|
||||
# Add dynamic tools from MCP
|
||||
try:
|
||||
mcp_tools = run_ai_async(self.mcp_manager.get_tools_for_llm(os_filter=os_filter)).result(timeout=10)
|
||||
base_tools.extend(mcp_tools)
|
||||
except Exception as e:
|
||||
# Silently fail for LLM tools
|
||||
pass
|
||||
|
||||
if self.architect_key:
|
||||
base_tools.extend([
|
||||
{"type": "function", "function": {"name": "consult_architect", "description": "Ask the Strategic Reasoning Engine for advice on complex design, architecture, or troubleshooting decisions. You remain in control and will present the response to the user. Use this for: configuration planning, design validation, complex troubleshooting.", "parameters": {"type": "object", "properties": {"question": {"type": "string", "description": "Strategic question or decision needed."}, "technical_summary": {"type": "string", "description": "Technical findings and context gathered so far."}}, "required": ["question", "technical_summary"]}}},
|
||||
@@ -1202,6 +1220,8 @@ class ai:
|
||||
elif fn == "run_commands": obs = self.run_commands_tool(**args, status=status)
|
||||
elif fn == "get_node_info": obs = self.get_node_info_tool(**args)
|
||||
elif fn == "manage_memory_tool": obs = self.manage_memory_tool(**args)
|
||||
elif fn.startswith("mcp_"):
|
||||
obs = run_ai_async(self.mcp_manager.call_tool(fn, args)).result(timeout=60)
|
||||
elif fn in self.external_tool_handlers: obs = self.external_tool_handlers[fn](self, **args)
|
||||
else: obs = f"Error: {fn} unknown."
|
||||
|
||||
@@ -1268,146 +1288,6 @@ class ai:
|
||||
"streamed": streamed_response
|
||||
}
|
||||
|
||||
@MethodHook
|
||||
def ask_copilot(self, terminal_buffer, user_question, node_info=None, chunk_callback=None):
|
||||
"""Single-shot copilot for augmented terminal sessions.
|
||||
|
||||
Args:
|
||||
terminal_buffer: Sanitized terminal screen content (últimas N líneas).
|
||||
user_question: Pregunta del usuario sobre la sesión activa.
|
||||
node_info: Optional dict con metadata del nodo (os, name, etc.)
|
||||
chunk_callback: Optional callable for streaming the guide.
|
||||
|
||||
Returns:
|
||||
dict: {commands: list[str], guide: str, risk_level: str, error: str|None}
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
|
||||
node_info = node_info or {}
|
||||
os_info = node_info.get("os", "unknown")
|
||||
node_name = node_info.get("name", "unknown")
|
||||
|
||||
# Load vendor-specific command reference if available
|
||||
vendor_reference = ""
|
||||
if os_info and os_info != "unknown":
|
||||
try:
|
||||
os_filename = os_info.lower().replace(" ", "_")
|
||||
ref_path = os.path.join(self.config.defaultdir, "ai_references", f"{os_filename}.md")
|
||||
if os.path.exists(ref_path):
|
||||
with open(ref_path, "r") as f:
|
||||
vendor_reference = f.read().strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
system_prompt = f"""Role: TERMINAL COPILOT. You assist a network engineer during a live SSH session.
|
||||
Rules:
|
||||
1. Answer the user's question directly based on the Terminal Context.
|
||||
2. If the user asks you to analyze, parse, or extract data from the Terminal Context, DO IT directly in the <guide> section (you can use markdown tables or lists). Do NOT just give them a command to do it themselves.
|
||||
3. If the user wants to execute an action, provide the required CLI commands inside a <commands> block, one command per line. If no commands are needed, leave it empty or omit the block.
|
||||
4. ULTRA-CONCISE. Keep your guide to the point.
|
||||
5. You MUST output your response in the following strict format:
|
||||
<guide>
|
||||
Your brief tactical guide in markdown. 3-4 sentences max.
|
||||
</guide>
|
||||
<commands>
|
||||
command 1
|
||||
command 2
|
||||
</commands>
|
||||
<risk>
|
||||
low, high, or destructive
|
||||
</risk>
|
||||
6. Risk level: "low" for read-only/no commands, "high" for config changes, "destructive" for potentially dangerous ops.
|
||||
|
||||
Terminal Context:
|
||||
{terminal_buffer}
|
||||
|
||||
Device OS: {os_info}
|
||||
Node: {node_name}"""
|
||||
|
||||
if vendor_reference:
|
||||
system_prompt += f"\n\nVendor Command Reference:\n{vendor_reference}"
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_question}
|
||||
]
|
||||
|
||||
try:
|
||||
response = completion(
|
||||
model=self.engineer_model,
|
||||
messages=messages,
|
||||
api_key=self.engineer_key,
|
||||
stream=True
|
||||
)
|
||||
|
||||
full_content = ""
|
||||
streamed_guide = ""
|
||||
|
||||
for chunk in response:
|
||||
delta = chunk.choices[0].delta
|
||||
if hasattr(delta, 'content') and delta.content:
|
||||
full_content += delta.content
|
||||
|
||||
if chunk_callback:
|
||||
start_idx = full_content.find("<guide>")
|
||||
if start_idx != -1:
|
||||
after_start = full_content[start_idx + 7:]
|
||||
end_idx = after_start.find("</guide>")
|
||||
|
||||
if end_idx != -1:
|
||||
current_guide = after_start[:end_idx]
|
||||
else:
|
||||
current_guide = after_start
|
||||
if current_guide.endswith("<"): current_guide = current_guide[:-1]
|
||||
elif current_guide.endswith("</"): current_guide = current_guide[:-2]
|
||||
elif current_guide.endswith("</g"): current_guide = current_guide[:-3]
|
||||
elif current_guide.endswith("</gu"): current_guide = current_guide[:-4]
|
||||
elif current_guide.endswith("</gui"): current_guide = current_guide[:-5]
|
||||
elif current_guide.endswith("</guid"): current_guide = current_guide[:-6]
|
||||
elif current_guide.endswith("</guide"): current_guide = current_guide[:-7]
|
||||
|
||||
new_text = current_guide[len(streamed_guide):]
|
||||
if new_text:
|
||||
chunk_callback(new_text)
|
||||
streamed_guide += new_text
|
||||
|
||||
guide = ""
|
||||
commands = []
|
||||
risk_level = "low"
|
||||
|
||||
guide_match = re.search(r"<guide>(.*?)</guide>", full_content, re.DOTALL)
|
||||
if guide_match:
|
||||
guide = guide_match.group(1).strip()
|
||||
|
||||
cmd_match = re.search(r"<commands>(.*?)</commands>", full_content, re.DOTALL)
|
||||
if cmd_match:
|
||||
cmds_raw = cmd_match.group(1).strip()
|
||||
if cmds_raw:
|
||||
commands = [c.strip() for c in cmds_raw.split('\n') if c.strip()]
|
||||
|
||||
risk_match = re.search(r"<risk>(.*?)</risk>", full_content, re.DOTALL)
|
||||
if risk_match:
|
||||
risk_level = risk_match.group(1).strip().lower()
|
||||
|
||||
if not guide and full_content and not ("<guide>" in full_content):
|
||||
guide = full_content.strip()
|
||||
|
||||
return {
|
||||
"commands": commands,
|
||||
"guide": guide,
|
||||
"risk_level": risk_level,
|
||||
"error": None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"commands": [],
|
||||
"guide": "",
|
||||
"risk_level": "low",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
@MethodHook
|
||||
async def aask_copilot(self, terminal_buffer, user_question, node_info=None, chunk_callback=None):
|
||||
import json
|
||||
@@ -1463,49 +1343,136 @@ Node: {node_name}"""
|
||||
if vendor_reference:
|
||||
system_prompt += f"\n\nVendor Command Reference:\n{vendor_reference}"
|
||||
|
||||
# Fetch MCP tools for the current OS
|
||||
mcp_tools = []
|
||||
try:
|
||||
mcp_tools = await self.mcp_manager.get_tools_for_llm(os_filter=os_info)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if mcp_tools:
|
||||
system_prompt += f"\n\nAvailable MCP Tools: {', '.join([t['function']['name'] for t in mcp_tools])}"
|
||||
system_prompt += "\nUse these tools to validate syntax or find exact commands if needed before providing the final guide."
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_question}
|
||||
]
|
||||
|
||||
iteration = 0
|
||||
max_iterations = 5 # Allow up to 5 iterations for tool usage
|
||||
|
||||
try:
|
||||
response = await acompletion(
|
||||
model=self.engineer_model,
|
||||
messages=messages,
|
||||
api_key=self.engineer_key,
|
||||
stream=True
|
||||
)
|
||||
while iteration < max_iterations:
|
||||
iteration += 1
|
||||
response = await acompletion(
|
||||
model=self.engineer_model,
|
||||
messages=messages,
|
||||
tools=mcp_tools if mcp_tools else None,
|
||||
api_key=self.engineer_key,
|
||||
stream=True
|
||||
)
|
||||
|
||||
full_content = ""
|
||||
streamed_guide = ""
|
||||
full_content = ""
|
||||
streamed_guide = ""
|
||||
tool_calls = []
|
||||
|
||||
async for chunk in response:
|
||||
delta = chunk.choices[0].delta
|
||||
if hasattr(delta, 'content') and delta.content:
|
||||
full_content += delta.content
|
||||
async for chunk in response:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
if chunk_callback:
|
||||
start_idx = full_content.find("<guide>")
|
||||
if start_idx != -1:
|
||||
after_start = full_content[start_idx + 7:]
|
||||
end_idx = after_start.find("</guide>")
|
||||
|
||||
if end_idx != -1:
|
||||
current_guide = after_start[:end_idx]
|
||||
# Accumulate tool calls
|
||||
if hasattr(delta, 'tool_calls') and delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx >= len(tool_calls):
|
||||
tool_calls.append({"id": tc.id, "type": "function", "function": {"name": tc.function.name or "", "arguments": tc.function.arguments or ""}})
|
||||
else:
|
||||
current_guide = after_start
|
||||
if current_guide.endswith("<"): current_guide = current_guide[:-1]
|
||||
elif current_guide.endswith("</"): current_guide = current_guide[:-2]
|
||||
elif current_guide.endswith("</g"): current_guide = current_guide[:-3]
|
||||
elif current_guide.endswith("</gu"): current_guide = current_guide[:-4]
|
||||
elif current_guide.endswith("</gui"): current_guide = current_guide[:-5]
|
||||
elif current_guide.endswith("</guid"): current_guide = current_guide[:-6]
|
||||
elif current_guide.endswith("</guide"): current_guide = current_guide[:-7]
|
||||
if tc.id: tool_calls[idx]["id"] = tc.id
|
||||
if tc.function.name: tool_calls[idx]["function"]["name"] = tc.function.name
|
||||
if tc.function.arguments: tool_calls[idx]["function"]["arguments"] += tc.function.arguments
|
||||
|
||||
new_text = current_guide[len(streamed_guide):]
|
||||
if new_text:
|
||||
chunk_callback(new_text)
|
||||
streamed_guide += new_text
|
||||
if hasattr(delta, 'content') and delta.content:
|
||||
full_content += delta.content
|
||||
|
||||
if chunk_callback and not tool_calls: # Only stream if not using tools
|
||||
start_idx = full_content.find("<guide>")
|
||||
if start_idx != -1:
|
||||
after_start = full_content[start_idx + 7:]
|
||||
end_idx = after_start.find("</guide>")
|
||||
|
||||
if end_idx != -1:
|
||||
current_guide = after_start[:end_idx]
|
||||
else:
|
||||
current_guide = after_start
|
||||
if current_guide.endswith("<"): current_guide = current_guide[:-1]
|
||||
elif current_guide.endswith("</"): current_guide = current_guide[:-2]
|
||||
elif current_guide.endswith("</g"): current_guide = current_guide[:-3]
|
||||
elif current_guide.endswith("</gu"): current_guide = current_guide[:-4]
|
||||
elif current_guide.endswith("</gui"): current_guide = current_guide[:-5]
|
||||
elif current_guide.endswith("</guid"): current_guide = current_guide[:-6]
|
||||
elif current_guide.endswith("</guide"): current_guide = current_guide[:-7]
|
||||
|
||||
new_text = current_guide[len(streamed_guide):]
|
||||
if new_text:
|
||||
chunk_callback(new_text)
|
||||
streamed_guide += new_text
|
||||
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
# Execute tool calls
|
||||
messages.append({"role": "assistant", "content": full_content or None, "tool_calls": tool_calls})
|
||||
for tc in tool_calls:
|
||||
fn = tc["function"]["name"]
|
||||
args = json.loads(tc["function"]["arguments"])
|
||||
|
||||
if "mcp_" in fn:
|
||||
try:
|
||||
obs = await asyncio.wait_for(self.mcp_manager.call_tool(fn, args), timeout=30.0)
|
||||
except Exception as e:
|
||||
obs = f"Error calling MCP tool: {e}"
|
||||
else:
|
||||
obs = f"Error: Tool {fn} not allowed in Copilot."
|
||||
|
||||
messages.append({"tool_call_id": tc["id"], "role": "tool", "name": fn, "content": self._truncate(str(obs))})
|
||||
|
||||
# If we hit the limit and it was still using tools, force a final answer
|
||||
if tool_calls and iteration >= max_iterations:
|
||||
messages.append({"role": "user", "content": "Tool limit reached. Provide your final tactical guide now based on the findings."})
|
||||
response = await acompletion(
|
||||
model=self.engineer_model,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
api_key=self.engineer_key,
|
||||
stream=True
|
||||
)
|
||||
|
||||
full_content = ""
|
||||
streamed_guide = ""
|
||||
async for chunk in response:
|
||||
delta = chunk.choices[0].delta
|
||||
if hasattr(delta, 'content') and delta.content:
|
||||
full_content += delta.content
|
||||
if chunk_callback:
|
||||
start_idx = full_content.find("<guide>")
|
||||
if start_idx != -1:
|
||||
after_start = full_content[start_idx + 7:]
|
||||
end_idx = after_start.find("</guide>")
|
||||
if end_idx != -1:
|
||||
current_guide = after_start[:end_idx]
|
||||
else:
|
||||
current_guide = after_start
|
||||
if current_guide.endswith("<"): current_guide = current_guide[:-1]
|
||||
elif current_guide.endswith("</"): current_guide = current_guide[:-2]
|
||||
elif current_guide.endswith("</g"): current_guide = current_guide[:-3]
|
||||
elif current_guide.endswith("</gu"): current_guide = current_guide[:-4]
|
||||
elif current_guide.endswith("</gui"): current_guide = current_guide[:-5]
|
||||
elif current_guide.endswith("</guid"): current_guide = current_guide[:-6]
|
||||
elif current_guide.endswith("</guide"): current_guide = current_guide[:-7]
|
||||
new_text = current_guide[len(streamed_guide):]
|
||||
if new_text:
|
||||
chunk_callback(new_text)
|
||||
streamed_guide += new_text
|
||||
|
||||
guide = ""
|
||||
commands = []
|
||||
|
||||
+116
-1
@@ -32,6 +32,9 @@ class AIHandler:
|
||||
printer.error(str(e))
|
||||
return
|
||||
|
||||
if args.mcp is not None:
|
||||
return self.configure_mcp(args)
|
||||
|
||||
# Determinar session_id para retomar
|
||||
session_id = None
|
||||
if args.resume:
|
||||
@@ -110,7 +113,7 @@ class AIHandler:
|
||||
try:
|
||||
user_query = Prompt.ask("[user_prompt]User[/user_prompt]")
|
||||
if not user_query.strip(): continue
|
||||
if user_query.lower() in ['exit', 'quit', 'bye']: break
|
||||
if user_query.lower() in ['exit', 'quit', 'bye', 'cancel']: break
|
||||
|
||||
with console.status("[ai_status]Agent is thinking...") as status:
|
||||
result = self.app.myai.ask(user_query, chat_history=history, status=status, debug=args.debug, trust=args.trust, **self.ai_overrides)
|
||||
@@ -134,3 +137,115 @@ class AIHandler:
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
console.print("\n[dim]Session closed.[/dim]")
|
||||
break
|
||||
|
||||
def configure_mcp(self, args):
|
||||
"""Handle MCP server configuration via CLI tokens or interactive wizard."""
|
||||
mcp_args = args.mcp
|
||||
|
||||
# 1. Non-interactive CLI Mode (if arguments are provided)
|
||||
if mcp_args:
|
||||
action = mcp_args[0].lower()
|
||||
|
||||
if action == "list":
|
||||
settings = self.app.services.config_svc.get_settings()
|
||||
mcp_servers = settings.get("ai", {}).get("mcp_servers", {})
|
||||
if not mcp_servers:
|
||||
printer.info("No MCP servers configured.")
|
||||
else:
|
||||
columns = ["Name", "URL", "Enabled", "Auto-load OS"]
|
||||
rows = []
|
||||
for name, cfg in mcp_servers.items():
|
||||
rows.append([
|
||||
name,
|
||||
cfg.get("url", ""),
|
||||
"[green]Yes[/green]" if cfg.get("enabled", True) else "[red]No[/red]",
|
||||
cfg.get("auto_load_on_os", "Any")
|
||||
])
|
||||
printer.table("Configured MCP Servers", columns, rows)
|
||||
return
|
||||
|
||||
elif action == "add":
|
||||
if len(mcp_args) < 3:
|
||||
printer.error("Usage: connpy ai --mcp add <name> <url> [os_filter]")
|
||||
return
|
||||
name, url = mcp_args[1], mcp_args[2]
|
||||
os_filter = mcp_args[3] if len(mcp_args) > 3 else None
|
||||
try:
|
||||
self.app.services.ai.configure_mcp(name, url=url, auto_load_on_os=os_filter)
|
||||
printer.success(f"MCP server '{name}' added/updated.")
|
||||
except Exception as e:
|
||||
printer.error(str(e))
|
||||
return
|
||||
|
||||
elif action == "remove":
|
||||
if len(mcp_args) < 2:
|
||||
printer.error("Usage: connpy ai --mcp remove <name>")
|
||||
return
|
||||
name = mcp_args[1]
|
||||
try:
|
||||
self.app.services.ai.configure_mcp(name, remove=True)
|
||||
printer.success(f"MCP server '{name}' removed.")
|
||||
except Exception as e:
|
||||
printer.error(str(e))
|
||||
return
|
||||
|
||||
elif action in ["enable", "disable"]:
|
||||
if len(mcp_args) < 2:
|
||||
printer.error(f"Usage: connpy ai --mcp {action} <name>")
|
||||
return
|
||||
name = mcp_args[1]
|
||||
enabled = (action == "enable")
|
||||
try:
|
||||
self.app.services.ai.configure_mcp(name, enabled=enabled)
|
||||
printer.success(f"MCP server '{name}' {'enabled' if enabled else 'disabled'}.")
|
||||
except Exception as e:
|
||||
printer.error(str(e))
|
||||
return
|
||||
|
||||
else:
|
||||
printer.error(f"Unknown MCP action: {action}")
|
||||
printer.info("Available actions: list, add, remove, enable, disable")
|
||||
return
|
||||
|
||||
# 2. Interactive Wizard Mode (if no arguments provided)
|
||||
# Import forms dynamically to avoid circular dependencies if any
|
||||
if not hasattr(self.app, "cli_forms"):
|
||||
from .forms import Forms
|
||||
self.app.cli_forms = Forms(self.app)
|
||||
|
||||
settings = self.app.services.config_svc.get_settings()
|
||||
mcp_servers = settings.get("ai", {}).get("mcp_servers", {})
|
||||
|
||||
result = self.app.cli_forms.mcp_wizard(mcp_servers)
|
||||
if not result:
|
||||
return
|
||||
|
||||
action = result["action"]
|
||||
try:
|
||||
if action == "list":
|
||||
# Recursive call to the non-interactive list logic
|
||||
args.mcp = ["list"]
|
||||
return self.configure_mcp(args)
|
||||
|
||||
elif action == "add":
|
||||
self.app.services.ai.configure_mcp(
|
||||
result["name"],
|
||||
url=result["url"],
|
||||
enabled=result["enabled"],
|
||||
auto_load_on_os=result["os"]
|
||||
)
|
||||
printer.success(f"MCP server '{result['name']}' saved.")
|
||||
|
||||
elif action == "update": # Used for toggle
|
||||
self.app.services.ai.configure_mcp(
|
||||
result["name"],
|
||||
enabled=result["enabled"]
|
||||
)
|
||||
printer.success(f"MCP server '{result['name']}' updated.")
|
||||
|
||||
elif action == "remove":
|
||||
self.app.services.ai.configure_mcp(result["name"], remove=True)
|
||||
printer.success(f"MCP server '{result['name']}' removed.")
|
||||
|
||||
except Exception as e:
|
||||
printer.error(str(e))
|
||||
|
||||
@@ -197,3 +197,84 @@ class Forms:
|
||||
answer["tags"] = ast.literal_eval(answer["tags"])
|
||||
|
||||
return answer
|
||||
|
||||
def mcp_wizard(self, mcp_servers):
|
||||
"""Interactive wizard to manage MCP servers."""
|
||||
from .helpers import theme
|
||||
|
||||
while True:
|
||||
options = [
|
||||
("List Configured Servers", "list"),
|
||||
("Add/Update Server", "add"),
|
||||
("Enable/Disable Server", "toggle"),
|
||||
("Remove Server", "remove"),
|
||||
("Back", "exit")
|
||||
]
|
||||
|
||||
questions = [
|
||||
inquirer.List("action", message="MCP Configuration", choices=options)
|
||||
]
|
||||
|
||||
answers = inquirer.prompt(questions, theme=theme)
|
||||
if not answers or answers["action"] == "exit":
|
||||
return None
|
||||
|
||||
action = answers["action"]
|
||||
|
||||
if action == "list":
|
||||
if not mcp_servers:
|
||||
print("\nNo MCP servers configured.\n")
|
||||
else:
|
||||
return {"action": "list"}
|
||||
|
||||
elif action == "add":
|
||||
questions = [
|
||||
inquirer.Text("name", message="Server Name (identifier)"),
|
||||
inquirer.Text("url", message="SSE URL (e.g., http://localhost:8000/sse)"),
|
||||
inquirer.Confirm("enabled", message="Enabled?", default=True),
|
||||
inquirer.Text("auto_load_os", message="Auto-load on specific OS (blank for any)")
|
||||
]
|
||||
answers = inquirer.prompt(questions, theme=theme)
|
||||
if answers:
|
||||
return {
|
||||
"action": "add",
|
||||
"name": answers["name"],
|
||||
"url": answers["url"],
|
||||
"enabled": answers["enabled"],
|
||||
"os": answers["auto_load_os"]
|
||||
}
|
||||
|
||||
elif action == "toggle":
|
||||
if not mcp_servers:
|
||||
print("\nNo servers to toggle.\n")
|
||||
continue
|
||||
|
||||
choices = []
|
||||
for name, cfg in mcp_servers.items():
|
||||
status = "[Enabled]" if cfg.get("enabled", True) else "[Disabled]"
|
||||
choices.append((f"{name} {status}", name))
|
||||
|
||||
questions = [
|
||||
inquirer.List("name", message="Select server to toggle", choices=choices + [("Cancel", None)])
|
||||
]
|
||||
answers = inquirer.prompt(questions, theme=theme)
|
||||
if answers and answers["name"]:
|
||||
current = mcp_servers[answers["name"]].get("enabled", True)
|
||||
return {
|
||||
"action": "update",
|
||||
"name": answers["name"],
|
||||
"enabled": not current
|
||||
}
|
||||
|
||||
elif action == "remove":
|
||||
if not mcp_servers:
|
||||
print("\nNo servers to remove.\n")
|
||||
continue
|
||||
|
||||
questions = [
|
||||
inquirer.List("name", message="Select server to remove", choices=list(mcp_servers.keys()) + ["Cancel"])
|
||||
]
|
||||
answers = inquirer.prompt(questions, theme=theme)
|
||||
if answers and answers["name"] != "Cancel":
|
||||
return {"action": "remove", "name": answers["name"]}
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,307 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import asyncio
|
||||
import fcntl
|
||||
import termios
|
||||
import tty
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
from textwrap import dedent
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from rich.live import Live
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
from prompt_toolkit.history import InMemoryHistory
|
||||
|
||||
from ..printer import connpy_theme
|
||||
|
||||
def log_cleaner(data: str) -> str:
|
||||
"""
|
||||
Stateless version of _logclean to remove ANSI sequences and process cursor movements.
|
||||
"""
|
||||
if not data:
|
||||
return ""
|
||||
|
||||
lines = data.split('\n')
|
||||
cleaned_lines = []
|
||||
|
||||
# Regex to capture: ANSI sequences, control characters (\r, \b, etc), and plain text chunks
|
||||
token_re = re.compile(r'(\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/ ]*[@-~])|\r|\b|\x7f|[\x00-\x1F]|[^\x1B\r\b\x7f\x00-\x1F]+)')
|
||||
|
||||
for line in lines:
|
||||
buffer = []
|
||||
cursor = 0
|
||||
|
||||
for token in token_re.findall(line):
|
||||
if token == '\r':
|
||||
cursor = 0
|
||||
elif token in ('\b', '\x7f'):
|
||||
if cursor > 0:
|
||||
cursor -= 1
|
||||
elif token == '\x1B[D': # Left Arrow
|
||||
if cursor > 0:
|
||||
cursor -= 1
|
||||
elif token == '\x1B[C': # Right Arrow
|
||||
if cursor < len(buffer):
|
||||
cursor += 1
|
||||
elif token == '\x1B[K': # Clear to end of line
|
||||
buffer = buffer[:cursor]
|
||||
elif token.startswith('\x1B'):
|
||||
continue
|
||||
elif len(token) == 1 and ord(token) < 32:
|
||||
continue
|
||||
else:
|
||||
for char in token:
|
||||
if cursor == len(buffer):
|
||||
buffer.append(char)
|
||||
else:
|
||||
buffer[cursor] = char
|
||||
cursor += 1
|
||||
cleaned_lines.append("".join(buffer))
|
||||
|
||||
return "\n".join(cleaned_lines).replace('\n\n', '\n').strip()
|
||||
|
||||
class CopilotInterface:
|
||||
def __init__(self, config, history=None):
|
||||
self.config = config
|
||||
self.console = Console(theme=connpy_theme)
|
||||
self.history = history or InMemoryHistory()
|
||||
self.mode_range, self.mode_single, self.mode_lines = 0, 1, 2
|
||||
|
||||
def extract_blocks(self, raw_bytes: bytes, cmd_byte_positions: List[tuple], node_info: dict) -> List[tuple]:
|
||||
"""Identifies command blocks in the terminal history."""
|
||||
blocks = []
|
||||
if not (cmd_byte_positions and len(cmd_byte_positions) >= 2 and raw_bytes):
|
||||
return blocks
|
||||
|
||||
default_prompt = r'>$|#$|\$$|>.$|#.$|\$.$'
|
||||
device_prompt = node_info.get("prompt", default_prompt) if isinstance(node_info, dict) else default_prompt
|
||||
prompt_re_str = re.sub(r'(?<!\\)\$', '', device_prompt)
|
||||
try:
|
||||
prompt_re = re.compile(prompt_re_str)
|
||||
except Exception:
|
||||
prompt_re = re.compile(re.sub(r'(?<!\\)\$', '', default_prompt))
|
||||
|
||||
for i in range(1, len(cmd_byte_positions)):
|
||||
pos, known_cmd = cmd_byte_positions[i]
|
||||
prev_pos = cmd_byte_positions[i-1][0]
|
||||
|
||||
if known_cmd:
|
||||
prev_chunk = raw_bytes[prev_pos:pos]
|
||||
prev_cleaned = log_cleaner(prev_chunk.decode(errors='replace'))
|
||||
prev_lines = [l for l in prev_cleaned.split('\n') if l.strip()]
|
||||
prompt_text = prev_lines[-1].strip() if prev_lines else ""
|
||||
preview = f"{prompt_text}{known_cmd}" if prompt_text else known_cmd
|
||||
blocks.append((pos, preview[:80]))
|
||||
else:
|
||||
chunk = raw_bytes[prev_pos:pos]
|
||||
cleaned = log_cleaner(chunk.decode(errors='replace'))
|
||||
lines = [l for l in cleaned.split('\n') if l.strip()]
|
||||
preview = lines[-1].strip() if lines else ""
|
||||
|
||||
if preview:
|
||||
match = prompt_re.search(preview)
|
||||
if match:
|
||||
cmd_text = preview[match.end():].strip()
|
||||
if cmd_text:
|
||||
blocks.append((pos, preview[:80]))
|
||||
return blocks
|
||||
|
||||
async def run_session(self,
|
||||
raw_bytes: bytes,
|
||||
cmd_byte_positions: List[tuple],
|
||||
node_info: dict,
|
||||
on_ai_call: Callable):
|
||||
"""
|
||||
Runs the interactive Copilot session.
|
||||
on_ai_call: async function(active_buffer, question) -> result_dict
|
||||
"""
|
||||
from rich.rule import Rule
|
||||
|
||||
try:
|
||||
# Prepare UI state
|
||||
buffer = log_cleaner(raw_bytes.decode(errors='replace'))
|
||||
blocks = self.extract_blocks(raw_bytes, cmd_byte_positions, node_info)
|
||||
last_line = buffer.split('\n')[-1].strip() if buffer.strip() else "(prompt)"
|
||||
blocks.append((len(raw_bytes), last_line[:80]))
|
||||
|
||||
state = {
|
||||
'context_cmd': 1,
|
||||
'total_cmds': len(blocks),
|
||||
'total_lines': len(buffer.split('\n')),
|
||||
'context_lines': min(50, len(buffer.split('\n'))),
|
||||
'context_mode': self.mode_range,
|
||||
'cancelled': False
|
||||
}
|
||||
|
||||
# 1. Visual Separation
|
||||
self.console.print("") # Salto de línea real
|
||||
self.console.print(Rule(title="[bold cyan] AI TERMINAL COPILOT [/bold cyan]", style="cyan"))
|
||||
self.console.print(Panel(
|
||||
"[dim]Type your question. Enter to send, Escape/Ctrl+C to cancel.\n"
|
||||
"Tab to change context mode. Ctrl+\u2191/\u2193 to adjust context. \u2191\u2193 for question history.[/dim]",
|
||||
border_style="cyan"
|
||||
))
|
||||
self.console.print("\n") # Pequeño espacio antes del prompt del copilot
|
||||
|
||||
bindings = KeyBindings()
|
||||
@bindings.add('c-up')
|
||||
def _(event):
|
||||
if state['context_mode'] == self.mode_lines:
|
||||
state['context_lines'] = min(state['context_lines'] + 50, state['total_lines'])
|
||||
else:
|
||||
state['context_cmd'] = min(state['context_cmd'] + 1, state['total_cmds'])
|
||||
event.app.invalidate()
|
||||
@bindings.add('c-down')
|
||||
def _(event):
|
||||
if state['context_mode'] == self.mode_lines:
|
||||
state['context_lines'] = max(state['context_lines'] - 50, min(50, state['total_lines']))
|
||||
else:
|
||||
state['context_cmd'] = max(state['context_cmd'] - 1, 1)
|
||||
event.app.invalidate()
|
||||
@bindings.add('tab')
|
||||
def _(event):
|
||||
state['context_mode'] = (state['context_mode'] + 1) % 3
|
||||
event.app.invalidate()
|
||||
@bindings.add('escape', eager=True)
|
||||
@bindings.add('c-c')
|
||||
def _(event):
|
||||
state['cancelled'] = True
|
||||
event.app.exit(result='')
|
||||
|
||||
def get_active_buffer():
|
||||
if state['context_mode'] == self.mode_lines:
|
||||
return '\n'.join(buffer.split('\n')[-state['context_lines']:])
|
||||
idx = max(0, state['total_cmds'] - state['context_cmd'])
|
||||
start, preview = blocks[idx]
|
||||
if state['context_mode'] == self.mode_single and idx + 1 < state['total_cmds']:
|
||||
end = blocks[idx + 1][0]
|
||||
active_raw = raw_bytes[start:end]
|
||||
else:
|
||||
active_raw = raw_bytes[start:]
|
||||
return preview + "\n" + log_cleaner(active_raw.decode(errors='replace'))
|
||||
|
||||
def get_prompt_text():
|
||||
if state['context_mode'] == self.mode_lines:
|
||||
return HTML(f"<ansicyan>Ask [Ctx: {state['context_lines']}/{state['total_lines']}L]: </ansicyan>")
|
||||
active = get_active_buffer()
|
||||
lines_count = len(active.split('\n'))
|
||||
mode_str = {self.mode_range: "Range", self.mode_single: "Cmd"}[state['context_mode']]
|
||||
return HTML(f"<ansicyan>Ask [{mode_str} {state['context_cmd']} ~{lines_count}L]: </ansicyan>")
|
||||
|
||||
def get_toolbar():
|
||||
m_label = {self.mode_range: "RANGE", self.mode_single: "SINGLE", self.mode_lines: "LINES"}[state['context_mode']]
|
||||
if state['context_mode'] == self.mode_lines:
|
||||
return HTML(f"<ansigray>\u25b6 Ctrl+\u2191/\u2193 adjusts by 50 lines [Tab: {m_label}]</ansigray>")
|
||||
idx = max(0, state['total_cmds'] - state['context_cmd'])
|
||||
return HTML(f"<ansigray>\u25b6 {blocks[idx][1]} [Tab: {m_label}]</ansigray>")
|
||||
|
||||
# 2. Ask question
|
||||
session = PromptSession(history=self.history)
|
||||
try:
|
||||
# Usamos un try/finally interno para asegurar que si algo falla en prompt_async,
|
||||
# no nos quedemos con la terminal en un estado extraño.
|
||||
question = await session.prompt_async(
|
||||
get_prompt_text,
|
||||
key_bindings=bindings,
|
||||
bottom_toolbar=get_toolbar
|
||||
)
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
state['cancelled'] = True
|
||||
question = ""
|
||||
|
||||
if state['cancelled'] or not question.strip() or question.strip().lower() == 'cancel':
|
||||
return "cancel", None, None
|
||||
|
||||
# Enrich question
|
||||
past = self.history.get_strings()
|
||||
if len(past) > 1:
|
||||
history_text = "\n".join(f"- {q}" for q in past[-6:-1])
|
||||
question = f"Previous questions:\n{history_text}\n\nCurrent Question:\n{question}"
|
||||
|
||||
# 3. AI Execution
|
||||
active_buffer = get_active_buffer()
|
||||
live_text = "Thinking..."
|
||||
panel = Panel(live_text, title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan")
|
||||
|
||||
def on_chunk(text):
|
||||
nonlocal live_text
|
||||
if live_text == "Thinking...": live_text = ""
|
||||
live_text += text
|
||||
|
||||
with Live(panel, console=self.console, refresh_per_second=10) as live:
|
||||
def update_live(t):
|
||||
live.update(Panel(Markdown(t), title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan"))
|
||||
|
||||
wrapped_chunk = lambda t: (on_chunk(t), update_live(live_text))
|
||||
|
||||
# Check for interruption during AI call
|
||||
ai_task = asyncio.create_task(on_ai_call(active_buffer, question, wrapped_chunk))
|
||||
|
||||
try:
|
||||
while not ai_task.done():
|
||||
await asyncio.sleep(0.05)
|
||||
result = await ai_task
|
||||
except asyncio.CancelledError:
|
||||
return "cancel", None, None
|
||||
|
||||
if not result or result.get("error"):
|
||||
if result and result.get("error"): self.console.print(f"[red]Error: {result['error']}[/red]")
|
||||
return "cancel", None, None
|
||||
|
||||
# 4. Handle result
|
||||
if live_text == "Thinking..." and result.get("guide"):
|
||||
self.console.print(Panel(Markdown(result["guide"]), title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan"))
|
||||
|
||||
commands = result.get("commands", [])
|
||||
if not commands:
|
||||
return "cancel", None, None
|
||||
|
||||
risk = result.get("risk_level", "low")
|
||||
style = {"low": "green", "high": "yellow", "destructive": "red"}.get(risk, "green")
|
||||
cmd_text = "\n".join(f" {i+1}. {c}" for i, c in enumerate(commands))
|
||||
self.console.print(Panel(cmd_text, title=f"[bold {style}]Suggested Commands [{risk.upper()}][/bold {style}]", border_style=style))
|
||||
|
||||
confirm_session = PromptSession()
|
||||
c_bindings = KeyBindings()
|
||||
@c_bindings.add('escape', eager=True)
|
||||
@c_bindings.add('c-c')
|
||||
def _(ev): ev.app.exit(result='n')
|
||||
|
||||
try:
|
||||
action = await confirm_session.prompt_async(HTML(f"<ansi{style}>Send? (y/n/e/number) [n]: </ansi{style}>"), key_bindings=c_bindings)
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
action = "n"
|
||||
|
||||
action_l = (action or "n").lower().strip()
|
||||
if action_l in ('y', 'yes', 'all'):
|
||||
return "send_all", commands, None
|
||||
elif action_l.startswith('e'):
|
||||
target = "\n".join(commands)
|
||||
e_bindings = KeyBindings()
|
||||
@e_bindings.add('c-j')
|
||||
def _(ev): ev.app.exit(result=ev.app.current_buffer.text)
|
||||
@e_bindings.add('escape', 'enter')
|
||||
def _(ev): ev.app.exit(result=ev.app.current_buffer.text)
|
||||
@e_bindings.add('escape')
|
||||
def _(ev): ev.app.exit(result='')
|
||||
|
||||
edited = await confirm_session.prompt_async(
|
||||
HTML("<ansicyan>Edit (Ctrl+Enter or Esc+Enter to submit):\n</ansicyan>"),
|
||||
default=target, multiline=True, key_bindings=e_bindings
|
||||
)
|
||||
if edited.strip():
|
||||
# Split by lines to ensure core.py applies delay between each command
|
||||
lines = [l.strip() for l in edited.split('\n') if l.strip()]
|
||||
return "custom", None, lines
|
||||
return "cancel", None, None
|
||||
|
||||
return "cancel", None, None
|
||||
|
||||
finally:
|
||||
self.console.print("[dim]Returning to session...[/dim]")
|
||||
|
||||
@@ -284,6 +284,7 @@ class connapp:
|
||||
aiparser.add_argument("--session", nargs=1, help="Resume a specific AI session by ID")
|
||||
aiparser.add_argument("--resume", action="store_true", help="Resume the most recent AI session")
|
||||
aiparser.add_argument("--delete", "--delete-session", dest="delete_session", nargs=1, help="Delete an AI session by ID")
|
||||
aiparser.add_argument("--mcp", nargs='*', metavar=('ACTION', 'NAME'), help="Manage MCP servers. Actions: list, add, remove, enable, disable. Leave empty for interactive wizard.")
|
||||
aiparser.set_defaults(func=self._ai.dispatch)
|
||||
#RUNPARSER
|
||||
runparser = subparsers.add_parser("run", help="Run scripts or commands on nodes", description="Run scripts or commands on nodes", formatter_class=RichHelpFormatter)
|
||||
|
||||
+92
-499
@@ -257,61 +257,29 @@ class node:
|
||||
|
||||
@MethodHook
|
||||
def _logclean(self, logfile, var = False):
|
||||
# Remove special ascii characters and process terminal cursor movements to clean logs.
|
||||
"""Remove special ascii characters and process terminal cursor movements to clean logs."""
|
||||
from .cli.terminal_ui import log_cleaner
|
||||
|
||||
if var == False:
|
||||
t = open(logfile, "r").read()
|
||||
try:
|
||||
with open(logfile, "r") as f:
|
||||
t = f.read()
|
||||
except:
|
||||
return
|
||||
else:
|
||||
t = logfile
|
||||
|
||||
lines = t.split('\n')
|
||||
cleaned_lines = []
|
||||
|
||||
# Regex to capture: ANSI sequences, control characters (\r, \b, etc), and plain text chunks
|
||||
token_re = re.compile(r'(\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/ ]*[@-~])|\r|\b|\x7f|[\x00-\x1F]|[^\x1B\r\b\x7f\x00-\x1F]+)')
|
||||
|
||||
for line in lines:
|
||||
buffer = []
|
||||
cursor = 0
|
||||
|
||||
for token in token_re.findall(line):
|
||||
if token == '\r':
|
||||
cursor = 0
|
||||
elif token in ('\b', '\x7f'):
|
||||
if cursor > 0:
|
||||
cursor -= 1
|
||||
elif token == '\x1B[D': # Left Arrow
|
||||
if cursor > 0:
|
||||
cursor -= 1
|
||||
elif token == '\x1B[C': # Right Arrow
|
||||
if cursor < len(buffer):
|
||||
cursor += 1
|
||||
elif token == '\x1B[K': # Clear to end of line
|
||||
buffer = buffer[:cursor]
|
||||
elif token.startswith('\x1B'):
|
||||
# Ignore other ANSI sequences (colors, etc)
|
||||
continue
|
||||
elif len(token) == 1 and ord(token) < 32:
|
||||
# Ignore other non-printable control chars
|
||||
continue
|
||||
else:
|
||||
# Regular printable text
|
||||
for char in token:
|
||||
if cursor == len(buffer):
|
||||
buffer.append(char)
|
||||
else:
|
||||
buffer[cursor] = char
|
||||
cursor += 1
|
||||
cleaned_lines.append("".join(buffer))
|
||||
|
||||
t = "\n".join(cleaned_lines).replace('\n\n', '\n').strip()
|
||||
result = log_cleaner(t)
|
||||
|
||||
if var == False:
|
||||
d = open(logfile, "w")
|
||||
d.write(t)
|
||||
d.close()
|
||||
try:
|
||||
with open(logfile, "w") as f:
|
||||
f.write(result)
|
||||
except:
|
||||
pass
|
||||
return
|
||||
else:
|
||||
return t
|
||||
return result
|
||||
|
||||
@MethodHook
|
||||
def _savelog(self):
|
||||
@@ -447,20 +415,22 @@ class node:
|
||||
|
||||
# Copilot interception
|
||||
if copilot_handler and b'\x00' in data:
|
||||
# Extract clean buffer from session log
|
||||
buffer = ""
|
||||
if hasattr(self, 'mylog'):
|
||||
raw = self.mylog.getvalue().decode(errors='replace')
|
||||
# Move heavy log cleaning to a thread
|
||||
buffer = await asyncio.to_thread(self._logclean, raw, True)
|
||||
# Build node info from available metadata and ensure values are strings (not bytes)
|
||||
def to_str(val):
|
||||
if isinstance(val, bytes):
|
||||
return val.decode(errors='replace')
|
||||
return str(val) if val is not None else "unknown"
|
||||
|
||||
# Build node info from available metadata
|
||||
node_info = {"name": getattr(self, 'unique', 'unknown'), "host": getattr(self, 'host', 'unknown')}
|
||||
node_info = {
|
||||
"name": to_str(getattr(self, 'unique', 'unknown')),
|
||||
"host": to_str(getattr(self, 'host', 'unknown'))
|
||||
}
|
||||
if isinstance(getattr(self, 'tags', None), dict):
|
||||
node_info["os"] = self.tags.get("os", "unknown")
|
||||
node_info["os"] = to_str(self.tags.get("os", "unknown"))
|
||||
node_info["prompt"] = to_str(self.tags.get("prompt", r'>$|#$|\$$|>.$|#.$|\$.$'))
|
||||
|
||||
# Invoke copilot (async callback handles UI)
|
||||
await copilot_handler(buffer, node_info, local_stream, child_fd, cmd_byte_positions)
|
||||
await copilot_handler(self.mylog.getvalue(), node_info, local_stream, child_fd, cmd_byte_positions)
|
||||
continue
|
||||
|
||||
# Remove any stray \x00 bytes and forward normally
|
||||
@@ -629,469 +599,83 @@ class node:
|
||||
def _build_local_copilot_handler(self):
|
||||
"""Build copilot handler for local CLI sessions using rich for rendering."""
|
||||
config = getattr(self, 'config', None) if hasattr(self, 'config') else None
|
||||
if not config:
|
||||
return None
|
||||
return self._copilot_handler(config)
|
||||
|
||||
# Persistent history across copilot invocations within the same session
|
||||
from prompt_toolkit.history import InMemoryHistory
|
||||
copilot_history = InMemoryHistory()
|
||||
def _copilot_handler(self, config):
|
||||
"""Unified copilot handler for local session."""
|
||||
from .cli.terminal_ui import CopilotInterface
|
||||
from .services.ai_service import AIService
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
async def handler(buffer, node_info, stream, child_fd, cmd_byte_positions=None):
|
||||
import termios, tty
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import fcntl
|
||||
|
||||
flags = 0
|
||||
stdin_fd = sys.stdin.fileno()
|
||||
|
||||
try:
|
||||
# Disable LocalStream reader so it doesn't steal keystrokes from Prompt
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.remove_reader(sys.stdin.fileno())
|
||||
interface = CopilotInterface(config, history=getattr(stream, 'copilot_history', None))
|
||||
# Save history back to stream for persistence in current session
|
||||
stream.copilot_history = interface.history
|
||||
|
||||
# 1. Salir de raw mode para poder usar input() y rich
|
||||
stdin_fd = sys.stdin.fileno()
|
||||
ai_service = AIService(config)
|
||||
|
||||
# Get true original settings saved before entering raw mode
|
||||
original_settings = getattr(stream, 'original_tty_settings', None)
|
||||
if original_settings:
|
||||
import copy
|
||||
new_settings = copy.deepcopy(original_settings)
|
||||
new_settings[3] = new_settings[3] & ~termios.ECHOCTL
|
||||
# CRITICAL: Prevent OS from translating Ctrl+C into SIGINT
|
||||
# This prevents the asyncio event loop from crashing when user hits Ctrl+C
|
||||
new_settings[3] = new_settings[3] & ~termios.ISIG
|
||||
termios.tcsetattr(stdin_fd, termios.TCSADRAIN, new_settings)
|
||||
async def on_ai_call(active_buffer, question, chunk_callback):
|
||||
return await ai_service.aask_copilot(
|
||||
active_buffer,
|
||||
question,
|
||||
node_info=node_info,
|
||||
chunk_callback=chunk_callback
|
||||
)
|
||||
|
||||
# Remove O_NONBLOCK from stdin so Prompt.ask() works
|
||||
import fcntl
|
||||
flags = fcntl.fcntl(stdin_fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(stdin_fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
|
||||
# Get raw bytes from BytesIO
|
||||
raw_bytes = self.mylog.getvalue()
|
||||
|
||||
# Force a carriage return so the UI doesn't start mid-line
|
||||
sys.stdout.write('\r\n')
|
||||
sys.stdout.flush()
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from rich.prompt import Prompt
|
||||
from .printer import connpy_theme
|
||||
|
||||
|
||||
console = Console(theme=connpy_theme)
|
||||
console.print("\n")
|
||||
console.print(Panel(
|
||||
"[bold cyan]AI Terminal Copilot[/bold cyan]\n"
|
||||
"[dim]Type your question. Enter to send, Escape/Ctrl+C to cancel.\n"
|
||||
"Tab to change context mode. Ctrl+\u2191/\u2193 to adjust context. \u2191\u2193 for question history.[/dim]",
|
||||
border_style="cyan"
|
||||
))
|
||||
|
||||
# 2. Capturar pregunta del usuario
|
||||
cancelled = [False]
|
||||
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
|
||||
bindings = KeyBindings()
|
||||
|
||||
# Command blocks logic
|
||||
raw_bytes = self.mylog.getvalue() if hasattr(self, 'mylog') else b''
|
||||
blocks = []
|
||||
|
||||
if cmd_byte_positions and len(cmd_byte_positions) >= 2 and raw_bytes:
|
||||
import re
|
||||
# Extract the prompt regex for validation
|
||||
default_prompt = r'>$|#$|\$$|>.$|#.$|\$.$'
|
||||
device_prompt = node_info.get("prompt", default_prompt) if isinstance(node_info, dict) else default_prompt
|
||||
# Remove unescaped $ end-anchors so we can match the prompt within the line
|
||||
prompt_re_str = re.sub(r'(?<!\\)\$', '', device_prompt)
|
||||
try:
|
||||
prompt_re = re.compile(prompt_re_str)
|
||||
except Exception:
|
||||
prompt_re = re.compile(re.sub(r'(?<!\\)\$', '', default_prompt))
|
||||
|
||||
for i in range(1, len(cmd_byte_positions)):
|
||||
pos, known_cmd = cmd_byte_positions[i]
|
||||
prev_pos = cmd_byte_positions[i-1][0]
|
||||
|
||||
if known_cmd:
|
||||
# AI-injected command: we already know the command text
|
||||
# Build preview from prompt (last line of previous chunk) + command
|
||||
prev_chunk = raw_bytes[prev_pos:pos]
|
||||
prev_cleaned = self._logclean(prev_chunk.decode(errors='replace'), var=True)
|
||||
prev_lines = [l for l in prev_cleaned.split('\n') if l.strip()]
|
||||
prompt_text = prev_lines[-1].strip() if prev_lines else ""
|
||||
preview = f"{prompt_text}{known_cmd}" if prompt_text else known_cmd
|
||||
blocks.append((pos, preview[:80]))
|
||||
else:
|
||||
# User-typed command: derive from raw log chunk
|
||||
chunk = raw_bytes[prev_pos:pos]
|
||||
cleaned = self._logclean(chunk.decode(errors='replace'), var=True)
|
||||
lines = [l for l in cleaned.split('\n') if l.strip()]
|
||||
preview = lines[-1].strip() if lines else ""
|
||||
|
||||
if preview:
|
||||
match = prompt_re.search(preview)
|
||||
if match:
|
||||
cmd_text = preview[match.end():].strip()
|
||||
# Only add if there is actual text typed (filters out empty enters and paginations)
|
||||
if cmd_text:
|
||||
blocks.append((pos, preview[:80]))
|
||||
|
||||
# Add synthetic "current prompt" block (zero context)
|
||||
last_line = buffer.split('\n')[-1].strip() if buffer.strip() else "(prompt)"
|
||||
blocks.append((len(raw_bytes), last_line[:80]))
|
||||
|
||||
context_cmd = [1]
|
||||
total_cmds = len(blocks)
|
||||
total_lines = len(buffer.split('\n'))
|
||||
context_lines = [min(50, total_lines)]
|
||||
# 0=range (cmd→END), 1=single (one cmd), 2=lines (adjustable by 50)
|
||||
context_mode = [0]
|
||||
MODE_RANGE, MODE_SINGLE, MODE_LINES = 0, 1, 2
|
||||
|
||||
@bindings.add('c-up')
|
||||
def _(event):
|
||||
if context_mode[0] == MODE_LINES:
|
||||
if context_lines[0] >= total_lines:
|
||||
context_lines[0] = min(50, total_lines)
|
||||
else:
|
||||
context_lines[0] = min(context_lines[0] + 50, total_lines)
|
||||
else:
|
||||
if context_cmd[0] < total_cmds:
|
||||
context_cmd[0] += 1
|
||||
else:
|
||||
context_cmd[0] = 1
|
||||
event.app.invalidate()
|
||||
|
||||
@bindings.add('c-down')
|
||||
def _(event):
|
||||
if context_mode[0] == MODE_LINES:
|
||||
if context_lines[0] <= min(50, total_lines):
|
||||
context_lines[0] = total_lines
|
||||
else:
|
||||
context_lines[0] = max(context_lines[0] - 50, min(50, total_lines))
|
||||
else:
|
||||
if context_cmd[0] > 1:
|
||||
context_cmd[0] -= 1
|
||||
else:
|
||||
context_cmd[0] = total_cmds
|
||||
event.app.invalidate()
|
||||
|
||||
@bindings.add('tab')
|
||||
def _(event):
|
||||
context_mode[0] = (context_mode[0] + 1) % 3
|
||||
event.app.invalidate()
|
||||
|
||||
@bindings.add('escape', eager=True)
|
||||
def _(event):
|
||||
cancelled[0] = True
|
||||
event.app.exit(result='')
|
||||
|
||||
def get_current_block():
|
||||
idx = max(0, total_cmds - context_cmd[0])
|
||||
return idx, blocks[idx]
|
||||
|
||||
def get_active_buffer():
|
||||
"""Build the active buffer for the current selection mode."""
|
||||
if context_mode[0] == MODE_LINES:
|
||||
buffer_lines = buffer.split('\n')
|
||||
return '\n'.join(buffer_lines[-context_lines[0]:])
|
||||
|
||||
idx, (start, preview) = get_current_block()
|
||||
if context_mode[0] == MODE_SINGLE and idx + 1 < total_cmds:
|
||||
end = blocks[idx + 1][0]
|
||||
active_raw = raw_bytes[start:end]
|
||||
else:
|
||||
active_raw = raw_bytes[start:]
|
||||
return preview + "\n" + self._logclean(active_raw.decode(errors='replace'), var=True)
|
||||
|
||||
def get_prompt_text():
|
||||
if context_mode[0] == MODE_LINES:
|
||||
return HTML(f"<ansicyan>Ask [Ctx: {context_lines[0]}/{total_lines}L]: </ansicyan>")
|
||||
|
||||
lines_count = len(get_active_buffer().split('\n'))
|
||||
if context_mode[0] == MODE_SINGLE:
|
||||
return HTML(f"<ansicyan>Ask [Cmd {context_cmd[0]} ~{lines_count}L]: </ansicyan>")
|
||||
else:
|
||||
return HTML(f"<ansicyan>Ask [Cmd {context_cmd[0]}\u2192END ~{lines_count}L]: </ansicyan>")
|
||||
|
||||
def get_toolbar():
|
||||
mode_labels = {MODE_RANGE: "RANGE", MODE_SINGLE: "SINGLE", MODE_LINES: "LINES"}
|
||||
mode_label = mode_labels[context_mode[0]]
|
||||
if context_mode[0] == MODE_LINES:
|
||||
return HTML(f"<ansigray>\u25b6 Ctrl+\u2191/\u2193 adjusts by 50 lines [Tab: {mode_label}]</ansigray>")
|
||||
_, (_, preview) = get_current_block()
|
||||
return HTML(f"<ansigray>\u25b6 {preview} [Tab: {mode_label}]</ansigray>")
|
||||
|
||||
import threading
|
||||
def preload_ai_deps():
|
||||
try:
|
||||
import litellm
|
||||
except Exception:
|
||||
pass
|
||||
threading.Thread(target=preload_ai_deps, daemon=True).start()
|
||||
# Detener el lector de la terminal para que prompt_toolkit (en run_session)
|
||||
# tenga control exclusivo del stdin sin interferencias de LocalStream.
|
||||
if hasattr(stream, 'stop_reading'):
|
||||
stream.stop_reading()
|
||||
elif hasattr(stream, '_loop') and hasattr(stream, 'stdin_fd'):
|
||||
# Fallback si no tiene el método (en LocalStream)
|
||||
stream._loop.remove_reader(stream.stdin_fd)
|
||||
|
||||
try:
|
||||
session = PromptSession(history=copilot_history)
|
||||
question = await session.prompt_async(
|
||||
get_prompt_text,
|
||||
key_bindings=bindings,
|
||||
bottom_toolbar=get_toolbar
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
question = ""
|
||||
|
||||
if cancelled[0] or not question.strip() or question.strip() == "CANCEL":
|
||||
console.print("\n[dim]Copilot cancelled.[/dim]")
|
||||
os.write(child_fd, b'\x15\r')
|
||||
return
|
||||
|
||||
active_buffer = get_active_buffer()
|
||||
|
||||
# 3. Llamar al AI con spinner
|
||||
from .services.ai_service import AIService
|
||||
service = AIService(config)
|
||||
|
||||
past_questions = copilot_history.get_strings()
|
||||
if len(past_questions) > 1:
|
||||
# Limit history to last 5 questions to save tokens, excluding current
|
||||
recent_history = past_questions[-6:-1]
|
||||
history_text = "\n".join(f"- {q}" for q in recent_history)
|
||||
enriched_question = f"Previous questions in this session:\n{history_text}\n\nCurrent Question:\n{question}"
|
||||
else:
|
||||
enriched_question = question
|
||||
|
||||
from rich.live import Live
|
||||
|
||||
live_text = "Thinking..."
|
||||
panel = Panel(live_text, title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan")
|
||||
|
||||
def on_chunk(text):
|
||||
nonlocal live_text
|
||||
if live_text == "Thinking...":
|
||||
live_text = ""
|
||||
live_text += text
|
||||
try:
|
||||
# Use call_soon_threadsafe if possible, but rich Live is thread-safe enough
|
||||
loop.call_soon_threadsafe(
|
||||
lambda: live.update(Panel(Markdown(live_text), title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan"))
|
||||
with copilot_terminal_mode():
|
||||
action, commands, custom_cmd = await interface.run_session(
|
||||
raw_bytes=raw_bytes,
|
||||
cmd_byte_positions=cmd_byte_positions,
|
||||
node_info=node_info,
|
||||
on_ai_call=on_ai_call
|
||||
)
|
||||
except Exception:
|
||||
live.update(Panel(Markdown(live_text), title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan"))
|
||||
finally:
|
||||
# Reiniciar el lector de la terminal para volver al modo interactivo SSH/Telnet
|
||||
if hasattr(stream, 'start_reading'):
|
||||
stream.start_reading()
|
||||
elif hasattr(stream, '_loop') and hasattr(stream, 'stdin_fd'):
|
||||
stream._loop.add_reader(stream.stdin_fd, stream._read_ready)
|
||||
|
||||
with copilot_terminal_mode(), Live(panel, console=console, refresh_per_second=10) as live:
|
||||
# Launch the AI call as a task
|
||||
ai_task = asyncio.create_task(service.aask_copilot(active_buffer, enriched_question, node_info, chunk_callback=on_chunk))
|
||||
if action in ("send_all", "custom"):
|
||||
cmds_to_send = commands if action == "send_all" else custom_cmd
|
||||
|
||||
# Make stdin non-blocking
|
||||
import fcntl
|
||||
flags = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
|
||||
fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
cancelled = False
|
||||
result = None
|
||||
|
||||
try:
|
||||
while not ai_task.done():
|
||||
try:
|
||||
key = os.read(sys.stdin.fileno(), 1024)
|
||||
if b'\x03' in key or b'\x1b' in key:
|
||||
cancelled = True
|
||||
ai_task.cancel()
|
||||
msg = "Ctrl+C" if b'\x03' in key else "Esc"
|
||||
console.print(f"\n[dim]Copilot cancelled via {msg}.[/dim]")
|
||||
break
|
||||
except OSError:
|
||||
pass
|
||||
# Yield to event loop to allow AI task to progress
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
if not cancelled:
|
||||
result = ai_task.result()
|
||||
except asyncio.CancelledError:
|
||||
cancelled = True
|
||||
console.print("\n[dim]Copilot cancelled.[/dim]")
|
||||
except KeyboardInterrupt:
|
||||
cancelled = True
|
||||
ai_task.cancel()
|
||||
console.print("\n[dim]Copilot cancelled via Ctrl+C.[/dim]")
|
||||
finally:
|
||||
# Restore stdin flags
|
||||
fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, flags)
|
||||
|
||||
if cancelled or not result:
|
||||
os.write(child_fd, b'\x15\r')
|
||||
return
|
||||
|
||||
if result.get("error"):
|
||||
console.print(f"[red]Error: {result['error']}[/red]")
|
||||
return
|
||||
|
||||
# If nothing was streamed (fallback), or to ensure final state
|
||||
if live_text == "Thinking..." and result.get("guide"):
|
||||
console.print(Panel(
|
||||
Markdown(result["guide"]),
|
||||
title="[bold cyan]Copilot Guide[/bold cyan]",
|
||||
border_style="cyan"
|
||||
))
|
||||
|
||||
commands = result.get("commands", [])
|
||||
risk = result.get("risk_level", "low")
|
||||
risk_style = {"low": "green", "high": "yellow", "destructive": "red"}.get(risk, "green")
|
||||
|
||||
if commands:
|
||||
cmd_text = "\n".join(f" {i+1}. {cmd}" for i, cmd in enumerate(commands))
|
||||
console.print(Panel(
|
||||
cmd_text,
|
||||
title=f"[bold {risk_style}]Suggested Commands [{risk.upper()}][/bold {risk_style}]",
|
||||
border_style=risk_style
|
||||
))
|
||||
|
||||
# 5. Preguntar si inyectar (usando prompt_toolkit)
|
||||
confirm_session = PromptSession()
|
||||
confirm_bindings = KeyBindings()
|
||||
|
||||
@confirm_bindings.add('escape', eager=True)
|
||||
def _(event):
|
||||
event.app.exit(result='n')
|
||||
|
||||
pt_color = "ansi" + risk_style
|
||||
try:
|
||||
action = await confirm_session.prompt_async(
|
||||
HTML(f"<{pt_color}>Send commands? (y/n/e/number/range) [n]: </{pt_color}>"),
|
||||
key_bindings=confirm_bindings
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
action = "n"
|
||||
|
||||
if not action.strip():
|
||||
action = "n"
|
||||
|
||||
console.print("[dim]Returning to session...[/dim]\n")
|
||||
|
||||
action_l = action.lower().strip()
|
||||
if action_l in ('y', 'yes', 'all'):
|
||||
os.write(child_fd, b'\x15') # Ctrl+U to clear line
|
||||
if cmds_to_send:
|
||||
os.write(child_fd, b'\x15') # Ctrl+U
|
||||
await asyncio.sleep(0.1)
|
||||
for cmd in commands:
|
||||
if cmd_byte_positions is not None and hasattr(self, 'mylog'):
|
||||
|
||||
# Prepend screen length command to avoid pagination
|
||||
if "screen_length_command" in self.tags:
|
||||
cmds_to_send.insert(0, self.tags["screen_length_command"])
|
||||
|
||||
for cmd in cmds_to_send:
|
||||
if cmd_byte_positions is not None:
|
||||
cmd_byte_positions.append((self.mylog.tell(), cmd))
|
||||
os.write(child_fd, (cmd + "\n").encode())
|
||||
await asyncio.sleep(0.3)
|
||||
elif action_l.startswith('e'):
|
||||
# Edit mode
|
||||
edit_session = PromptSession()
|
||||
cmds_to_edit = []
|
||||
|
||||
if len(action_l) > 1 and action_l[1:].isdigit():
|
||||
idx = int(action_l[1:]) - 1
|
||||
if 0 <= idx < len(commands):
|
||||
cmds_to_edit = [commands[idx]]
|
||||
else:
|
||||
cmds_to_edit = commands
|
||||
|
||||
if cmds_to_edit:
|
||||
target_cmd = "\n".join(cmds_to_edit)
|
||||
edit_bindings = KeyBindings()
|
||||
@edit_bindings.add('c-j')
|
||||
def _(event):
|
||||
event.app.exit(result=event.app.current_buffer.text)
|
||||
|
||||
@edit_bindings.add('escape', eager=True)
|
||||
def _(event):
|
||||
event.app.exit(result='')
|
||||
|
||||
try:
|
||||
edited_cmd = await edit_session.prompt_async(
|
||||
HTML("<ansicyan>Edit commands (Ctrl+Enter to submit, Esc to cancel):\n</ansicyan>"),
|
||||
default=target_cmd,
|
||||
multiline=True,
|
||||
key_bindings=edit_bindings
|
||||
)
|
||||
if edited_cmd.strip():
|
||||
os.write(child_fd, b'\x15')
|
||||
await asyncio.sleep(0.1)
|
||||
for cmd in edited_cmd.split('\n'):
|
||||
if cmd.strip():
|
||||
if cmd_byte_positions is not None and hasattr(self, 'mylog'):
|
||||
cmd_byte_positions.append((self.mylog.tell(), cmd.strip()))
|
||||
os.write(child_fd, (cmd.strip() + "\n").encode())
|
||||
await asyncio.sleep(0.3)
|
||||
else:
|
||||
os.write(child_fd, b'\x15\r')
|
||||
except KeyboardInterrupt:
|
||||
os.write(child_fd, b'\x15\r')
|
||||
else:
|
||||
os.write(child_fd, b'\x15\r')
|
||||
elif action_l not in ('n', 'no', ''):
|
||||
try:
|
||||
selected_indices = set()
|
||||
for part in action_l.split(','):
|
||||
part = part.strip()
|
||||
if not part: continue
|
||||
if '-' in part:
|
||||
start_str, end_str = part.split('-', 1)
|
||||
start = int(start_str) - 1
|
||||
end = int(end_str) - 1
|
||||
for i in range(start, end + 1):
|
||||
selected_indices.add(i)
|
||||
else:
|
||||
selected_indices.add(int(part) - 1)
|
||||
|
||||
valid_indices = sorted([i for i in selected_indices if 0 <= i < len(commands)])
|
||||
if valid_indices:
|
||||
os.write(child_fd, b'\x15') # Ctrl+U to clear line
|
||||
await asyncio.sleep(0.1)
|
||||
if len(valid_indices) == 1:
|
||||
if cmd_byte_positions is not None and hasattr(self, 'mylog'):
|
||||
cmd_byte_positions.append((self.mylog.tell(), commands[valid_indices[0]]))
|
||||
os.write(child_fd, (commands[valid_indices[0]] + "\n").encode())
|
||||
else:
|
||||
for idx in valid_indices:
|
||||
if cmd_byte_positions is not None and hasattr(self, 'mylog'):
|
||||
cmd_byte_positions.append((self.mylog.tell(), commands[idx]))
|
||||
os.write(child_fd, (commands[idx] + "\n").encode())
|
||||
await asyncio.sleep(0.3)
|
||||
else:
|
||||
os.write(child_fd, b'\x15\r') # Ctrl+U + Enter to abort line and get new prompt
|
||||
except ValueError:
|
||||
os.write(child_fd, b'\x15\r')
|
||||
else:
|
||||
os.write(child_fd, b'\x15\r')
|
||||
await asyncio.sleep(0.8)
|
||||
else:
|
||||
console.print("[dim]Returning to session...[/dim]\n")
|
||||
os.write(child_fd, b'\x15\r')
|
||||
except KeyboardInterrupt:
|
||||
if 'console' in locals():
|
||||
console.print("\n[dim]Copilot cancelled via Ctrl+C.[/dim]\n")
|
||||
else:
|
||||
print("\n[dim]Copilot cancelled via Ctrl+C.[/dim]\n")
|
||||
os.write(child_fd, b'\x15\r')
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"\n[ERROR in Copilot Handler] {e}", flush=True)
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# 6. Restaurar raw mode, O_NONBLOCK y SIGINT
|
||||
tty.setraw(stdin_fd)
|
||||
fcntl.fcntl(stdin_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
# Re-enable LocalStream reader
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.add_reader(stdin_fd, stream._read_ready)
|
||||
except Exception:
|
||||
pass
|
||||
os.write(child_fd, b'\x15\r')
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
@MethodHook
|
||||
def run(self, commands, vars = None,*, folder = '', prompt = r'>$|#$|\$$|>.$|#.$|\$.$', stdout = False, timeout = 10, logger = None):
|
||||
'''
|
||||
@@ -1152,7 +736,6 @@ class node:
|
||||
if "prompt" in self.tags:
|
||||
prompt = self.tags["prompt"]
|
||||
expects = [prompt, pexpect.EOF, pexpect.TIMEOUT]
|
||||
|
||||
output = ''
|
||||
status = ''
|
||||
if not isinstance(commands, list):
|
||||
@@ -1263,7 +846,6 @@ class node:
|
||||
if "prompt" in self.tags:
|
||||
prompt = self.tags["prompt"]
|
||||
expects = [prompt, pexpect.EOF, pexpect.TIMEOUT]
|
||||
|
||||
output = ''
|
||||
if not isinstance(commands, list):
|
||||
commands = [commands]
|
||||
@@ -1329,8 +911,6 @@ class node:
|
||||
@MethodHook
|
||||
def _generate_ssh_sftp_cmd(self):
|
||||
cmd = self.protocol
|
||||
if self.idletime > 0:
|
||||
cmd += " -o ServerAliveInterval=" + str(self.idletime)
|
||||
if self.port:
|
||||
if self.protocol == "ssh":
|
||||
cmd += " -p " + self.port
|
||||
@@ -1385,6 +965,19 @@ class node:
|
||||
cmd += f" {self.options}"
|
||||
return cmd
|
||||
|
||||
@MethodHook
|
||||
def _generate_ssm_cmd(self):
|
||||
region = self.tags.get("region", "") if isinstance(self.tags, dict) else ""
|
||||
profile = self.tags.get("profile", "") if isinstance(self.tags, dict) else ""
|
||||
cmd = f"aws ssm start-session --target {self.host}"
|
||||
if region:
|
||||
cmd += f" --region {region}"
|
||||
if profile:
|
||||
cmd += f" --profile {profile}"
|
||||
if self.options:
|
||||
cmd += f" {self.options}"
|
||||
return cmd
|
||||
|
||||
@MethodHook
|
||||
def _get_cmd(self):
|
||||
if self.protocol in ["ssh", "sftp"]:
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
+40
-48
@@ -207,59 +207,19 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
import json
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
|
||||
# Build context blocks like local CLI does
|
||||
blocks = []
|
||||
raw_bytes = n.mylog.getvalue() if hasattr(n, 'mylog') else b''
|
||||
|
||||
if cmd_byte_positions and len(cmd_byte_positions) >= 2 and raw_bytes:
|
||||
default_prompt = r'>$|#$|\$$|>.$|#.$|\$.$'
|
||||
device_prompt = node_info.get("prompt", default_prompt) if isinstance(node_info, dict) else default_prompt
|
||||
prompt_re_str = re.sub(r'(?<!\\)\$', '', device_prompt)
|
||||
try:
|
||||
prompt_re = re.compile(prompt_re_str)
|
||||
except Exception:
|
||||
prompt_re = re.compile(re.sub(r'(?<!\\)\$', '', default_prompt))
|
||||
|
||||
for i in range(1, len(cmd_byte_positions)):
|
||||
pos, known_cmd = cmd_byte_positions[i]
|
||||
prev_pos = cmd_byte_positions[i-1][0]
|
||||
|
||||
if known_cmd:
|
||||
prev_chunk = raw_bytes[prev_pos:pos]
|
||||
prev_cleaned = n._logclean(prev_chunk.decode(errors='replace'), var=True)
|
||||
prev_lines = [l for l in prev_cleaned.split('\n') if l.strip()]
|
||||
prompt_text = prev_lines[-1].strip() if prev_lines else ""
|
||||
preview = f"{prompt_text}{known_cmd}" if prompt_text else known_cmd
|
||||
blocks.append({"pos": pos, "preview": preview[:80], "type": "cmd"})
|
||||
else:
|
||||
chunk = raw_bytes[prev_pos:pos]
|
||||
cleaned = n._logclean(chunk.decode(errors='replace'), var=True)
|
||||
lines = [l for l in cleaned.split('\n') if l.strip()]
|
||||
preview = lines[-1].strip() if lines else ""
|
||||
|
||||
if preview:
|
||||
match = prompt_re.search(preview)
|
||||
if match:
|
||||
cmd_text = preview[match.end():].strip()
|
||||
if cmd_text:
|
||||
blocks.append({"pos": pos, "preview": preview[:80], "type": "cmd"})
|
||||
|
||||
clean_buffer = n._logclean(raw_bytes.decode(errors='replace'), var=True)
|
||||
last_line = clean_buffer.split('\n')[-1].strip() if clean_buffer.strip() else "(prompt)"
|
||||
blocks.append({"pos": len(raw_bytes), "preview": last_line[:80], "type": "current"})
|
||||
|
||||
if node_info is None:
|
||||
node_info = {}
|
||||
node_info["context_blocks"] = blocks
|
||||
node_info["full_buffer"] = buffer
|
||||
|
||||
node_info_json = json.dumps(node_info)
|
||||
|
||||
# Convert buffer to string if it's bytes for the preview
|
||||
preview_str = buffer[-200:].decode(errors='replace') if isinstance(buffer, bytes) else str(buffer)[-200:]
|
||||
|
||||
# 1. Send prompt to client
|
||||
response_queue.put(connpy_pb2.InteractResponse(
|
||||
copilot_prompt=True,
|
||||
copilot_buffer_preview=buffer[-200:],
|
||||
copilot_buffer_preview=preview_str,
|
||||
copilot_node_info_json=node_info_json
|
||||
))
|
||||
|
||||
@@ -349,19 +309,33 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
commands = result.get("commands", [])
|
||||
os.write(child_fd, b'\x15') # Ctrl+U to clear line
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Prepend screen length command to avoid pagination
|
||||
if "screen_length_command" in n.tags:
|
||||
os.write(child_fd, (n.tags["screen_length_command"] + "\n").encode())
|
||||
response_queue.put(connpy_pb2.InteractResponse(copilot_injected_command=n.tags["screen_length_command"]))
|
||||
await asyncio.sleep(0.8)
|
||||
|
||||
for cmd in commands:
|
||||
os.write(child_fd, (cmd + "\n").encode())
|
||||
response_queue.put(connpy_pb2.InteractResponse(copilot_injected_command=cmd))
|
||||
await asyncio.sleep(0.3)
|
||||
await asyncio.sleep(0.8)
|
||||
elif action.startswith("custom:"):
|
||||
custom_cmds = action[7:]
|
||||
os.write(child_fd, b'\x15')
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Prepend screen length command to avoid pagination
|
||||
if "screen_length_command" in n.tags:
|
||||
os.write(child_fd, (n.tags["screen_length_command"] + "\n").encode())
|
||||
response_queue.put(connpy_pb2.InteractResponse(copilot_injected_command=n.tags["screen_length_command"]))
|
||||
await asyncio.sleep(0.8)
|
||||
|
||||
for cmd in custom_cmds.split('\n'):
|
||||
if cmd.strip():
|
||||
os.write(child_fd, (cmd.strip() + "\n").encode())
|
||||
response_queue.put(connpy_pb2.InteractResponse(copilot_injected_command=cmd.strip()))
|
||||
await asyncio.sleep(0.3)
|
||||
await asyncio.sleep(0.8)
|
||||
elif action not in ('cancel', 'n', 'no'):
|
||||
# Handle numbers and ranges like "1,2,4-6"
|
||||
try:
|
||||
@@ -383,10 +357,17 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
if valid_indices:
|
||||
os.write(child_fd, b'\x15')
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Prepend screen length command to avoid pagination
|
||||
if "screen_length_command" in n.tags:
|
||||
os.write(child_fd, (n.tags["screen_length_command"] + "\n").encode())
|
||||
response_queue.put(connpy_pb2.InteractResponse(copilot_injected_command=n.tags["screen_length_command"]))
|
||||
await asyncio.sleep(0.8)
|
||||
|
||||
for idx in valid_indices:
|
||||
os.write(child_fd, (commands[idx] + "\n").encode())
|
||||
response_queue.put(connpy_pb2.InteractResponse(copilot_injected_command=commands[idx]))
|
||||
await asyncio.sleep(0.3)
|
||||
await asyncio.sleep(0.8)
|
||||
else:
|
||||
os.write(child_fd, b'\x15\r')
|
||||
except (ValueError, IndexError):
|
||||
@@ -980,6 +961,17 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
||||
self.service.configure_provider(request.provider, request.model, request.api_key)
|
||||
return Empty()
|
||||
|
||||
@handle_errors
|
||||
def configure_mcp(self, request, context):
|
||||
self.service.configure_mcp(
|
||||
request.name,
|
||||
url=request.url or None,
|
||||
enabled=request.enabled,
|
||||
auto_load_on_os=request.auto_load_on_os or None,
|
||||
remove=request.remove
|
||||
)
|
||||
return Empty()
|
||||
|
||||
@handle_errors
|
||||
def load_session_data(self, request, context):
|
||||
return connpy_pb2.StructResponse(data=to_struct(self.service.load_session_data(request.value)))
|
||||
|
||||
+76
-634
@@ -9,6 +9,7 @@ from .utils import to_value, from_value, to_struct, from_struct
|
||||
from ..services.exceptions import ConnpyError
|
||||
from ..hooks import MethodHook
|
||||
from .. import printer
|
||||
from ..cli.terminal_ui import log_cleaner, CopilotInterface
|
||||
|
||||
def handle_errors(func):
|
||||
@wraps(func)
|
||||
@@ -41,6 +42,60 @@ class NodeStub:
|
||||
self.remote_host = remote_host
|
||||
self.config = config
|
||||
|
||||
def _handle_remote_copilot(self, res, request_queue, response_queue, client_buffer_bytes, cmd_byte_positions, pause_generator, resume_generator, old_tty):
|
||||
import json, asyncio, termios, sys, tty, queue
|
||||
from ..core import copilot_terminal_mode
|
||||
from . import connpy_pb2
|
||||
|
||||
pause_generator()
|
||||
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
interface = CopilotInterface(self.config, history=getattr(self, 'copilot_history', None))
|
||||
self.copilot_history = interface.history
|
||||
|
||||
node_info = json.loads(res.copilot_node_info_json) if res.copilot_node_info_json else {}
|
||||
|
||||
async def on_ai_call_remote(active_buffer, question, chunk_callback):
|
||||
# Send request to server
|
||||
request_queue.put(connpy_pb2.InteractRequest(
|
||||
copilot_question=question,
|
||||
copilot_context_buffer=active_buffer
|
||||
))
|
||||
# Wait for chunks from server
|
||||
while True:
|
||||
try:
|
||||
chunk_res = response_queue.get(timeout=0.1)
|
||||
if chunk_res is None: return {"error": "Server disconnected"}
|
||||
if chunk_res.copilot_stream_chunk:
|
||||
chunk_callback(chunk_res.copilot_stream_chunk)
|
||||
elif chunk_res.copilot_response_json:
|
||||
return json.loads(chunk_res.copilot_response_json)
|
||||
except queue.Empty:
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Wrap in async loop
|
||||
async def run_remote_copilot():
|
||||
return await interface.run_session(
|
||||
raw_bytes=bytes(client_buffer_bytes),
|
||||
cmd_byte_positions=cmd_byte_positions,
|
||||
node_info=node_info,
|
||||
on_ai_call=on_ai_call_remote
|
||||
)
|
||||
|
||||
with copilot_terminal_mode():
|
||||
action, commands, custom_cmd = asyncio.run(run_remote_copilot())
|
||||
|
||||
# Prepare final action for server
|
||||
action_sent = "cancel"
|
||||
if action == "send_all":
|
||||
action_sent = "send_all"
|
||||
elif action == "custom" and custom_cmd:
|
||||
action_sent = f"custom:{chr(10).join(custom_cmd)}"
|
||||
|
||||
request_queue.put(connpy_pb2.InteractRequest(copilot_action=action_sent))
|
||||
resume_generator()
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
|
||||
@handle_errors
|
||||
def connect_node(self, unique_id, sftp=False, debug=False, logger=None):
|
||||
import sys
|
||||
@@ -173,325 +228,11 @@ class NodeStub:
|
||||
if res is None:
|
||||
break
|
||||
if res.copilot_prompt:
|
||||
pause_generator()
|
||||
import json
|
||||
import asyncio
|
||||
import re
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
from prompt_toolkit.history import InMemoryHistory
|
||||
from ..printer import connpy_theme
|
||||
from ..core import copilot_terminal_mode
|
||||
|
||||
if not hasattr(self, 'copilot_history'):
|
||||
self.copilot_history = InMemoryHistory()
|
||||
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
import fcntl
|
||||
flags = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
|
||||
fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
|
||||
console = Console(theme=connpy_theme)
|
||||
console.print("\n")
|
||||
console.print(Panel(
|
||||
"[bold cyan]AI Terminal Copilot[/bold cyan]\n"
|
||||
"[dim]Type your question. Enter to send, Escape/Ctrl+C to cancel.\n"
|
||||
"Tab to change context mode. Ctrl+\u2191/\u2193 to adjust context. \u2191\u2193 for question history.[/dim]",
|
||||
border_style="cyan"
|
||||
))
|
||||
|
||||
node_info = json.loads(res.copilot_node_info_json) if res.copilot_node_info_json else {}
|
||||
|
||||
# Logic for context selection
|
||||
blocks = []
|
||||
raw_bytes = client_buffer_bytes
|
||||
from ..core import node
|
||||
dummy_node = node("dummy", "dummy") # For logclean
|
||||
|
||||
if cmd_byte_positions and len(cmd_byte_positions) >= 2 and raw_bytes:
|
||||
default_prompt = r'>$|#$|\$$|>.$|#.$|\$.$'
|
||||
device_prompt = node_info.get("prompt", default_prompt)
|
||||
prompt_re_str = re.sub(r'(?<!\\)\$', '', device_prompt)
|
||||
try:
|
||||
prompt_re = re.compile(prompt_re_str)
|
||||
except Exception:
|
||||
prompt_re = re.compile(re.sub(r'(?<!\\)\$', '', default_prompt))
|
||||
|
||||
for i in range(1, len(cmd_byte_positions)):
|
||||
pos, known_cmd = cmd_byte_positions[i]
|
||||
prev_pos = cmd_byte_positions[i-1][0]
|
||||
|
||||
if known_cmd:
|
||||
# AI-injected command: we already know the command text
|
||||
prev_chunk = raw_bytes[prev_pos:pos]
|
||||
prev_cleaned = dummy_node._logclean(prev_chunk.decode(errors='replace'), var=True)
|
||||
prev_lines = [l for l in prev_cleaned.split('\n') if l.strip()]
|
||||
prompt_text = prev_lines[-1].strip() if prev_lines else ""
|
||||
preview = f"{prompt_text}{known_cmd}" if prompt_text else known_cmd
|
||||
blocks.append((pos, preview[:80]))
|
||||
else:
|
||||
# User-typed command: derive from raw log chunk
|
||||
chunk = raw_bytes[prev_pos:pos]
|
||||
cleaned = dummy_node._logclean(chunk.decode(errors='replace'), var=True)
|
||||
lines = [l for l in cleaned.split('\n') if l.strip()]
|
||||
preview = lines[-1].strip() if lines else ""
|
||||
|
||||
if preview:
|
||||
match = prompt_re.search(preview)
|
||||
if match:
|
||||
cmd_text = preview[match.end():].strip()
|
||||
if cmd_text:
|
||||
blocks.append((pos, preview[:80]))
|
||||
|
||||
clean_buffer = dummy_node._logclean(raw_bytes.decode(errors='replace'), var=True)
|
||||
last_line = clean_buffer.split('\n')[-1].strip() if clean_buffer.strip() else "(prompt)"
|
||||
blocks.append((len(raw_bytes), last_line[:80]))
|
||||
|
||||
context_cmd = [1]
|
||||
total_cmds = len(blocks)
|
||||
total_lines = len(clean_buffer.split('\n'))
|
||||
context_lines = [min(50, total_lines)]
|
||||
context_mode = [0]
|
||||
MODE_RANGE, MODE_SINGLE, MODE_LINES = 0, 1, 2
|
||||
|
||||
bindings = KeyBindings()
|
||||
|
||||
@bindings.add('c-up')
|
||||
def _(event):
|
||||
if context_mode[0] == MODE_LINES:
|
||||
if context_lines[0] >= total_lines:
|
||||
context_lines[0] = min(50, total_lines)
|
||||
else:
|
||||
context_lines[0] = min(context_lines[0] + 50, total_lines)
|
||||
else:
|
||||
if context_cmd[0] < total_cmds:
|
||||
context_cmd[0] += 1
|
||||
else:
|
||||
context_cmd[0] = 1
|
||||
event.app.invalidate()
|
||||
|
||||
@bindings.add('c-down')
|
||||
def _(event):
|
||||
if context_mode[0] == MODE_LINES:
|
||||
if context_lines[0] <= min(50, total_lines):
|
||||
context_lines[0] = total_lines
|
||||
else:
|
||||
context_lines[0] = max(context_lines[0] - 50, min(50, total_lines))
|
||||
else:
|
||||
if context_cmd[0] > 1:
|
||||
context_cmd[0] -= 1
|
||||
else:
|
||||
context_cmd[0] = total_cmds
|
||||
event.app.invalidate()
|
||||
|
||||
@bindings.add('tab')
|
||||
def _(event):
|
||||
context_mode[0] = (context_mode[0] + 1) % 3
|
||||
event.app.invalidate()
|
||||
|
||||
@bindings.add('escape', eager=True)
|
||||
def _(event):
|
||||
event.app.exit(result='')
|
||||
|
||||
def get_current_block():
|
||||
idx = max(0, total_cmds - context_cmd[0])
|
||||
return idx, blocks[idx]
|
||||
|
||||
def get_active_buffer():
|
||||
if context_mode[0] == MODE_LINES:
|
||||
buffer_lines = clean_buffer.split('\n')
|
||||
return '\n'.join(buffer_lines[-context_lines[0]:])
|
||||
|
||||
idx, (start, preview) = get_current_block()
|
||||
if context_mode[0] == MODE_SINGLE and idx + 1 < total_cmds:
|
||||
end = blocks[idx + 1][0]
|
||||
active_raw = raw_bytes[start:end]
|
||||
else:
|
||||
active_raw = raw_bytes[start:]
|
||||
return preview + "\n" + dummy_node._logclean(active_raw.decode(errors='replace'), var=True)
|
||||
|
||||
def get_prompt_text():
|
||||
if context_mode[0] == MODE_LINES:
|
||||
return HTML(f"<ansicyan>Ask [Ctx: {context_lines[0]}/{total_lines}L]: </ansicyan>")
|
||||
|
||||
lines_count = len(get_active_buffer().split('\n'))
|
||||
if context_mode[0] == MODE_SINGLE:
|
||||
return HTML(f"<ansicyan>Ask [Cmd {context_cmd[0]} ~{lines_count}L]: </ansicyan>")
|
||||
else:
|
||||
return HTML(f"<ansicyan>Ask [Cmd {context_cmd[0]}\u2192END ~{lines_count}L]: </ansicyan>")
|
||||
|
||||
def get_toolbar():
|
||||
mode_labels = {MODE_RANGE: "RANGE", MODE_SINGLE: "SINGLE", MODE_LINES: "LINES"}
|
||||
mode_label = mode_labels[context_mode[0]]
|
||||
if context_mode[0] == MODE_LINES:
|
||||
return HTML(f"<ansigray>\u25b6 Ctrl+\u2191/\u2193 adjusts by 50 lines [Tab: {mode_label}]</ansigray>")
|
||||
_, (_, preview) = get_current_block()
|
||||
return HTML(f"<ansigray>\u25b6 {preview} [Tab: {mode_label}]</ansigray>")
|
||||
|
||||
try:
|
||||
session = PromptSession(history=self.copilot_history)
|
||||
question = session.prompt(get_prompt_text, key_bindings=bindings, bottom_toolbar=get_toolbar)
|
||||
except KeyboardInterrupt:
|
||||
question = ""
|
||||
|
||||
if not question or not question.strip() or question.strip() == "CANCEL":
|
||||
console.print("\n[dim]Copilot cancelled.[/dim]")
|
||||
request_queue.put(connpy_pb2.InteractRequest(copilot_question="CANCEL"))
|
||||
resume_generator()
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
continue
|
||||
|
||||
active_buffer = get_active_buffer()
|
||||
# Enrich question with history (same as local CLI)
|
||||
past_questions = self.copilot_history.get_strings()
|
||||
if len(past_questions) > 1:
|
||||
# Limit history to last 5 questions to save tokens, excluding current
|
||||
recent_history = past_questions[-6:-1]
|
||||
history_text = "\n".join(f"- {q}" for q in recent_history)
|
||||
enriched_question = f"Previous questions in this session:\n{history_text}\n\nCurrent Question:\n{question}"
|
||||
else:
|
||||
enriched_question = question
|
||||
|
||||
request_queue.put(connpy_pb2.InteractRequest(copilot_question=enriched_question, copilot_context_buffer=active_buffer))
|
||||
|
||||
from rich.live import Live
|
||||
live_text = "Thinking..."
|
||||
panel = Panel(live_text, title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan")
|
||||
result = {}
|
||||
cancelled = False
|
||||
|
||||
with copilot_terminal_mode(), Live(panel, console=console, refresh_per_second=10) as live:
|
||||
# Make stdin non-blocking to check for Ctrl+C locally
|
||||
import fcntl
|
||||
flags = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
|
||||
fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
while True:
|
||||
# 1. Read input for Ctrl+C
|
||||
try:
|
||||
key = os.read(sys.stdin.fileno(), 1024)
|
||||
if b'\x03' in key or b'\x1b' in key:
|
||||
cancelled = True
|
||||
request_queue.put(connpy_pb2.InteractRequest(copilot_question="CANCEL"))
|
||||
msg = "Ctrl+C" if b'\x03' in key else "Esc"
|
||||
console.print(f"\n[dim]Copilot cancelled via {msg}.[/dim]")
|
||||
break
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# 2. Wait for response chunk
|
||||
try:
|
||||
chunk_res = response_queue.get(timeout=0.1)
|
||||
if chunk_res is None:
|
||||
break
|
||||
|
||||
if chunk_res.copilot_stream_chunk:
|
||||
if live_text == "Thinking...": live_text = ""
|
||||
live_text += chunk_res.copilot_stream_chunk
|
||||
live.update(Panel(Markdown(live_text), title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan"))
|
||||
elif chunk_res.copilot_response_json:
|
||||
result = json.loads(chunk_res.copilot_response_json)
|
||||
break
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
# Restore blocking mode
|
||||
fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, flags)
|
||||
|
||||
if cancelled:
|
||||
resume_generator()
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
continue
|
||||
|
||||
if result.get("error"):
|
||||
console.print(f"[red]Error: {result['error']}[/red]")
|
||||
request_queue.put(connpy_pb2.InteractRequest(copilot_action="cancel"))
|
||||
resume_generator()
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
continue
|
||||
|
||||
if live_text == "Thinking..." and result.get("guide"):
|
||||
console.print(Panel(Markdown(result["guide"]), title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan"))
|
||||
|
||||
commands = result.get("commands", [])
|
||||
risk = result.get("risk_level", "low")
|
||||
risk_style = {"low": "green", "high": "yellow", "destructive": "red"}.get(risk, "green")
|
||||
|
||||
action_sent = "cancel"
|
||||
if commands:
|
||||
cmd_text = "\n".join(f" {i+1}. {cmd}" for i, cmd in enumerate(commands))
|
||||
console.print(Panel(
|
||||
cmd_text,
|
||||
title=f"[bold {risk_style}]Suggested Commands [{risk.upper()}][/bold {risk_style}]",
|
||||
border_style=risk_style
|
||||
))
|
||||
|
||||
try:
|
||||
confirm_session = PromptSession()
|
||||
confirm_bindings = KeyBindings()
|
||||
@confirm_bindings.add('escape', eager=True)
|
||||
def _(event):
|
||||
event.app.exit(result='n')
|
||||
|
||||
pt_color = "ansi" + risk_style
|
||||
action = confirm_session.prompt(
|
||||
HTML(f"<{pt_color}>Send commands? (y/n/e/number/range) [n]: </{pt_color}>"),
|
||||
key_bindings=confirm_bindings
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
action = "n"
|
||||
|
||||
if not action.strip():
|
||||
action = "n"
|
||||
|
||||
action_l = action.lower().strip()
|
||||
if action_l in ('y', 'yes', 'all'):
|
||||
action_sent = "send_all"
|
||||
elif action_l.startswith('e'):
|
||||
action_sent = f"edit_{action_l[1:]}" if len(action_l) > 1 else "edit_all"
|
||||
# For remote editing, the client edits and sends back as custom action
|
||||
edit_session = PromptSession()
|
||||
cmds_to_edit = []
|
||||
if action_sent.startswith("edit_") and action_sent[5:].isdigit():
|
||||
idx = int(action_sent[5:]) - 1
|
||||
if 0 <= idx < len(commands):
|
||||
cmds_to_edit = [commands[idx]]
|
||||
else:
|
||||
cmds_to_edit = commands
|
||||
|
||||
if cmds_to_edit:
|
||||
target_cmd = "\n".join(cmds_to_edit)
|
||||
try:
|
||||
edit_bindings = KeyBindings()
|
||||
@edit_bindings.add('c-j')
|
||||
def _(event):
|
||||
event.app.exit(result=event.app.current_buffer.text)
|
||||
@edit_bindings.add('escape', eager=True)
|
||||
def _(event):
|
||||
event.app.exit(result='')
|
||||
|
||||
edited_cmd = edit_session.prompt(
|
||||
HTML("<ansicyan>Edit commands (Ctrl+Enter to submit, Esc to cancel):\n</ansicyan>"),
|
||||
default=target_cmd,
|
||||
multiline=True,
|
||||
key_bindings=edit_bindings
|
||||
)
|
||||
if edited_cmd.strip():
|
||||
action_sent = "custom:" + edited_cmd.strip()
|
||||
else:
|
||||
action_sent = "cancel"
|
||||
except KeyboardInterrupt:
|
||||
action_sent = "cancel"
|
||||
elif action_l not in ('n', 'no', ''):
|
||||
action_sent = action_l
|
||||
|
||||
console.print("[dim]Returning to session...[/dim]\n")
|
||||
request_queue.put(connpy_pb2.InteractRequest(copilot_action=action_sent))
|
||||
resume_generator()
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
self._handle_remote_copilot(
|
||||
res, request_queue, response_queue,
|
||||
client_buffer_bytes, cmd_byte_positions,
|
||||
pause_generator, resume_generator, old_tty
|
||||
)
|
||||
continue
|
||||
|
||||
if res.copilot_injected_command:
|
||||
@@ -638,321 +379,11 @@ class NodeStub:
|
||||
if res is None:
|
||||
break
|
||||
if res.copilot_prompt:
|
||||
pause_generator()
|
||||
import json
|
||||
import asyncio
|
||||
import re
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.markdown import Markdown
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
from prompt_toolkit.history import InMemoryHistory
|
||||
from ..printer import connpy_theme
|
||||
from ..core import copilot_terminal_mode
|
||||
|
||||
if not hasattr(self, 'copilot_history'):
|
||||
self.copilot_history = InMemoryHistory()
|
||||
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
import fcntl
|
||||
flags = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
|
||||
fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
|
||||
console = Console(theme=connpy_theme)
|
||||
console.print("\n")
|
||||
console.print(Panel(
|
||||
"[bold cyan]AI Terminal Copilot[/bold cyan]\n"
|
||||
"[dim]Type your question. Enter to send, Escape/Ctrl+C to cancel.\n"
|
||||
"Tab to change context mode. Ctrl+\u2191/\u2193 to adjust context. \u2191\u2193 for question history.[/dim]",
|
||||
border_style="cyan"
|
||||
))
|
||||
|
||||
node_info = json.loads(res.copilot_node_info_json) if res.copilot_node_info_json else {}
|
||||
|
||||
# Logic for context selection
|
||||
blocks = []
|
||||
raw_bytes = client_buffer_bytes
|
||||
from ..core import node
|
||||
dummy_node = node("dummy", "dummy") # For logclean
|
||||
|
||||
if cmd_byte_positions and len(cmd_byte_positions) >= 2 and raw_bytes:
|
||||
default_prompt = r'>$|#$|\$$|>.$|#.$|\$.$'
|
||||
device_prompt = node_info.get("prompt", default_prompt)
|
||||
prompt_re_str = re.sub(r'(?<!\\)\$', '', device_prompt)
|
||||
try:
|
||||
prompt_re = re.compile(prompt_re_str)
|
||||
except Exception:
|
||||
prompt_re = re.compile(re.sub(r'(?<!\\)\$', '', default_prompt))
|
||||
|
||||
for i in range(1, len(cmd_byte_positions)):
|
||||
pos, known_cmd = cmd_byte_positions[i]
|
||||
prev_pos = cmd_byte_positions[i-1][0]
|
||||
|
||||
if known_cmd:
|
||||
# AI-injected command: we already know the command text
|
||||
prev_chunk = raw_bytes[prev_pos:pos]
|
||||
prev_cleaned = dummy_node._logclean(prev_chunk.decode(errors='replace'), var=True)
|
||||
prev_lines = [l for l in prev_cleaned.split('\n') if l.strip()]
|
||||
prompt_text = prev_lines[-1].strip() if prev_lines else ""
|
||||
preview = f"{prompt_text}{known_cmd}" if prompt_text else known_cmd
|
||||
blocks.append((pos, preview[:80]))
|
||||
else:
|
||||
# User-typed command: derive from raw log chunk
|
||||
chunk = raw_bytes[prev_pos:pos]
|
||||
cleaned = dummy_node._logclean(chunk.decode(errors='replace'), var=True)
|
||||
lines = [l for l in cleaned.split('\n') if l.strip()]
|
||||
preview = lines[-1].strip() if lines else ""
|
||||
|
||||
if preview:
|
||||
match = prompt_re.search(preview)
|
||||
if match:
|
||||
cmd_text = preview[match.end():].strip()
|
||||
if cmd_text:
|
||||
blocks.append((pos, preview[:80]))
|
||||
|
||||
clean_buffer = dummy_node._logclean(raw_bytes.decode(errors='replace'), var=True)
|
||||
last_line = clean_buffer.split('\n')[-1].strip() if clean_buffer.strip() else "(prompt)"
|
||||
blocks.append((len(raw_bytes), last_line[:80]))
|
||||
|
||||
context_cmd = [1]
|
||||
total_cmds = len(blocks)
|
||||
total_lines = len(clean_buffer.split('\n'))
|
||||
context_lines = [min(50, total_lines)]
|
||||
context_mode = [0]
|
||||
MODE_RANGE, MODE_SINGLE, MODE_LINES = 0, 1, 2
|
||||
|
||||
bindings = KeyBindings()
|
||||
|
||||
@bindings.add('c-up')
|
||||
def _(event):
|
||||
if context_mode[0] == MODE_LINES:
|
||||
if context_lines[0] >= total_lines:
|
||||
context_lines[0] = min(50, total_lines)
|
||||
else:
|
||||
context_lines[0] = min(context_lines[0] + 50, total_lines)
|
||||
else:
|
||||
if context_cmd[0] < total_cmds:
|
||||
context_cmd[0] += 1
|
||||
else:
|
||||
context_cmd[0] = 1
|
||||
event.app.invalidate()
|
||||
|
||||
@bindings.add('c-down')
|
||||
def _(event):
|
||||
if context_mode[0] == MODE_LINES:
|
||||
if context_lines[0] <= min(50, total_lines):
|
||||
context_lines[0] = total_lines
|
||||
else:
|
||||
context_lines[0] = max(context_lines[0] - 50, min(50, total_lines))
|
||||
else:
|
||||
if context_cmd[0] > 1:
|
||||
context_cmd[0] -= 1
|
||||
else:
|
||||
context_cmd[0] = total_cmds
|
||||
event.app.invalidate()
|
||||
|
||||
@bindings.add('tab')
|
||||
def _(event):
|
||||
context_mode[0] = (context_mode[0] + 1) % 3
|
||||
event.app.invalidate()
|
||||
|
||||
@bindings.add('escape', eager=True)
|
||||
def _(event):
|
||||
event.app.exit(result='')
|
||||
|
||||
def get_current_block():
|
||||
idx = max(0, total_cmds - context_cmd[0])
|
||||
return idx, blocks[idx]
|
||||
|
||||
def get_active_buffer():
|
||||
if context_mode[0] == MODE_LINES:
|
||||
buffer_lines = clean_buffer.split('\n')
|
||||
return '\n'.join(buffer_lines[-context_lines[0]:])
|
||||
|
||||
idx, (start, preview) = get_current_block()
|
||||
if context_mode[0] == MODE_SINGLE and idx + 1 < total_cmds:
|
||||
end = blocks[idx + 1][0]
|
||||
active_raw = raw_bytes[start:end]
|
||||
else:
|
||||
active_raw = raw_bytes[start:]
|
||||
return preview + "\n" + dummy_node._logclean(active_raw.decode(errors='replace'), var=True)
|
||||
|
||||
def get_prompt_text():
|
||||
if context_mode[0] == MODE_LINES:
|
||||
return HTML(f"<ansicyan>Ask [Ctx: {context_lines[0]}/{total_lines}L]: </ansicyan>")
|
||||
|
||||
lines_count = len(get_active_buffer().split('\n'))
|
||||
if context_mode[0] == MODE_SINGLE:
|
||||
return HTML(f"<ansicyan>Ask [Cmd {context_cmd[0]} ~{lines_count}L]: </ansicyan>")
|
||||
else:
|
||||
return HTML(f"<ansicyan>Ask [Cmd {context_cmd[0]}\u2192END ~{lines_count}L]: </ansicyan>")
|
||||
|
||||
def get_toolbar():
|
||||
mode_labels = {MODE_RANGE: "RANGE", MODE_SINGLE: "SINGLE", MODE_LINES: "LINES"}
|
||||
mode_label = mode_labels[context_mode[0]]
|
||||
if context_mode[0] == MODE_LINES:
|
||||
return HTML(f"<ansigray>\u25b6 Ctrl+\u2191/\u2193 adjusts by 50 lines [Tab: {mode_label}]</ansigray>")
|
||||
_, (_, preview) = get_current_block()
|
||||
return HTML(f"<ansigray>\u25b6 {preview} [Tab: {mode_label}]</ansigray>")
|
||||
|
||||
try:
|
||||
session = PromptSession(history=self.copilot_history)
|
||||
question = session.prompt(get_prompt_text, key_bindings=bindings, bottom_toolbar=get_toolbar)
|
||||
except KeyboardInterrupt:
|
||||
question = ""
|
||||
|
||||
if not question or not question.strip() or question.strip() == "CANCEL":
|
||||
console.print("\n[dim]Copilot cancelled.[/dim]")
|
||||
request_queue.put(connpy_pb2.InteractRequest(copilot_question="CANCEL"))
|
||||
resume_generator()
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
continue
|
||||
|
||||
active_buffer = get_active_buffer()
|
||||
# Enrich question with history (same as local CLI)
|
||||
past_questions = self.copilot_history.get_strings()
|
||||
if len(past_questions) > 1:
|
||||
# Limit history to last 5 questions to save tokens, excluding current
|
||||
recent_history = past_questions[-6:-1]
|
||||
history_text = "\n".join(f"- {q}" for q in recent_history)
|
||||
enriched_question = f"Previous questions in this session:\n{history_text}\n\nCurrent Question:\n{question}"
|
||||
else:
|
||||
enriched_question = question
|
||||
|
||||
request_queue.put(connpy_pb2.InteractRequest(copilot_question=enriched_question, copilot_context_buffer=active_buffer))
|
||||
|
||||
from rich.live import Live
|
||||
live_text = "Thinking..."
|
||||
panel = Panel(live_text, title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan")
|
||||
result = {}
|
||||
cancelled = False
|
||||
|
||||
with copilot_terminal_mode(), Live(panel, console=console, refresh_per_second=10) as live:
|
||||
import fcntl
|
||||
flags = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
|
||||
fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
while True:
|
||||
try:
|
||||
key = os.read(sys.stdin.fileno(), 1024)
|
||||
if b'\x03' in key or b'\x1b' in key:
|
||||
cancelled = True
|
||||
request_queue.put(connpy_pb2.InteractRequest(copilot_question="CANCEL"))
|
||||
msg = "Ctrl+C" if b'\x03' in key else "Esc"
|
||||
console.print(f"\n[dim]Copilot cancelled via {msg}.[/dim]")
|
||||
break
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
try:
|
||||
chunk_res = response_queue.get(timeout=0.1)
|
||||
if chunk_res is None:
|
||||
break
|
||||
|
||||
if chunk_res.copilot_stream_chunk:
|
||||
if live_text == "Thinking...": live_text = ""
|
||||
live_text += chunk_res.copilot_stream_chunk
|
||||
live.update(Panel(Markdown(live_text), title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan"))
|
||||
elif chunk_res.copilot_response_json:
|
||||
result = json.loads(chunk_res.copilot_response_json)
|
||||
break
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, flags)
|
||||
|
||||
if cancelled:
|
||||
resume_generator()
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
continue
|
||||
|
||||
if result.get("error"):
|
||||
console.print(f"[red]Error: {result['error']}[/red]")
|
||||
request_queue.put(connpy_pb2.InteractRequest(copilot_action="cancel"))
|
||||
resume_generator()
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
continue
|
||||
|
||||
if live_text == "Thinking..." and result.get("guide"):
|
||||
console.print(Panel(Markdown(result["guide"]), title="[bold cyan]Copilot Guide[/bold cyan]", border_style="cyan"))
|
||||
|
||||
commands = result.get("commands", [])
|
||||
risk = result.get("risk_level", "low")
|
||||
risk_style = {"low": "green", "high": "yellow", "destructive": "red"}.get(risk, "green")
|
||||
|
||||
action_sent = "cancel"
|
||||
if commands:
|
||||
cmd_text = "\n".join(f" {i+1}. {cmd}" for i, cmd in enumerate(commands))
|
||||
console.print(Panel(
|
||||
cmd_text,
|
||||
title=f"[bold {risk_style}]Suggested Commands [{risk.upper()}][/bold {risk_style}]",
|
||||
border_style=risk_style
|
||||
))
|
||||
|
||||
try:
|
||||
confirm_session = PromptSession()
|
||||
confirm_bindings = KeyBindings()
|
||||
@confirm_bindings.add('escape', eager=True)
|
||||
def _(event):
|
||||
event.app.exit(result='n')
|
||||
|
||||
pt_color = "ansi" + risk_style
|
||||
action = confirm_session.prompt(
|
||||
HTML(f"<{pt_color}>Send commands? (y/n/e/number/range) [n]: </{pt_color}>"),
|
||||
key_bindings=confirm_bindings
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
action = "n"
|
||||
|
||||
if not action.strip():
|
||||
action = "n"
|
||||
|
||||
action_l = action.lower().strip()
|
||||
if action_l in ('y', 'yes', 'all'):
|
||||
action_sent = "send_all"
|
||||
elif action_l.startswith('e'):
|
||||
action_sent = f"edit_{action_l[1:]}" if len(action_l) > 1 else "edit_all"
|
||||
# For remote editing, the client edits and sends back as custom action
|
||||
edit_session = PromptSession()
|
||||
cmds_to_edit = []
|
||||
if action_sent.startswith("edit_") and action_sent[5:].isdigit():
|
||||
idx = int(action_sent[5:]) - 1
|
||||
if 0 <= idx < len(commands):
|
||||
cmds_to_edit = [commands[idx]]
|
||||
else:
|
||||
cmds_to_edit = commands
|
||||
|
||||
if cmds_to_edit:
|
||||
target_cmd = "\n".join(cmds_to_edit)
|
||||
try:
|
||||
edit_bindings = KeyBindings()
|
||||
@edit_bindings.add('c-j')
|
||||
def _(event):
|
||||
event.app.exit(result=event.app.current_buffer.text)
|
||||
@edit_bindings.add('escape', eager=True)
|
||||
def _(event):
|
||||
event.app.exit(result='')
|
||||
|
||||
edited_cmd = edit_session.prompt(
|
||||
HTML("<ansicyan>Edit commands (Ctrl+Enter to submit, Esc to cancel):\n</ansicyan>"),
|
||||
default=target_cmd,
|
||||
multiline=True,
|
||||
key_bindings=edit_bindings
|
||||
)
|
||||
if edited_cmd.strip():
|
||||
action_sent = "custom:" + edited_cmd.strip()
|
||||
else:
|
||||
action_sent = "cancel"
|
||||
except KeyboardInterrupt:
|
||||
action_sent = "cancel"
|
||||
elif action_l not in ('n', 'no', ''):
|
||||
action_sent = action_l
|
||||
|
||||
console.print("[dim]Returning to session...[/dim]\n")
|
||||
request_queue.put(connpy_pb2.InteractRequest(copilot_action=action_sent))
|
||||
resume_generator()
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
self._handle_remote_copilot(
|
||||
res, request_queue, response_queue,
|
||||
client_buffer_bytes, cmd_byte_positions,
|
||||
pause_generator, resume_generator, old_tty
|
||||
)
|
||||
continue
|
||||
|
||||
if res.copilot_injected_command:
|
||||
@@ -1496,6 +927,17 @@ class AIStub:
|
||||
req = connpy_pb2.ProviderRequest(provider=provider, model=model or "", api_key=api_key or "")
|
||||
self.stub.configure_provider(req)
|
||||
|
||||
@handle_errors
|
||||
def configure_mcp(self, name, url=None, enabled=True, auto_load_on_os=None, remove=False):
|
||||
req = connpy_pb2.MCPRequest(
|
||||
name=name,
|
||||
url=url or "",
|
||||
enabled=enabled,
|
||||
auto_load_on_os=auto_load_on_os or "",
|
||||
remove=remove
|
||||
)
|
||||
self.stub.configure_mcp(req)
|
||||
|
||||
@handle_errors
|
||||
def load_session_data(self, session_id):
|
||||
return from_struct(self.stub.load_session_data(connpy_pb2.StringRequest(value=session_id)).data)
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
import logging
|
||||
|
||||
try:
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
MCP_AVAILABLE = True
|
||||
except ImportError:
|
||||
MCP_AVAILABLE = False
|
||||
|
||||
# Silence noisy MCP and HTTP internal logging
|
||||
logging.getLogger("mcp").setLevel(logging.CRITICAL)
|
||||
logging.getLogger("httpx").setLevel(logging.CRITICAL)
|
||||
logging.getLogger("httpcore").setLevel(logging.CRITICAL)
|
||||
|
||||
class MCPClientManager:
|
||||
"""Manages MCP SSE client connections for connpy."""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super(MCPClientManager, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config=None):
|
||||
if self._initialized:
|
||||
return
|
||||
self.config = config
|
||||
self.sessions: Dict[str, Dict[str, Any]] = {} # name -> {session, stack}
|
||||
self.tool_cache: Dict[str, List[Dict[str, Any]]] = {}
|
||||
self._connecting: Dict[str, asyncio.Future] = {}
|
||||
self._initialized = True
|
||||
|
||||
async def get_tools_for_llm(self, os_filter: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetches tools from enabled MCP servers that match the OS filter.
|
||||
"""
|
||||
if not MCP_AVAILABLE:
|
||||
return []
|
||||
|
||||
all_llm_tools = []
|
||||
try:
|
||||
mcp_config = self.config.config.get("ai", {}).get("mcp_servers", {})
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
async def _fetch(name, cfg):
|
||||
if not cfg.get("enabled", True): return []
|
||||
|
||||
# Filter by OS if specified in config (primarily used for copilot strict matching)
|
||||
auto_os = cfg.get("auto_load_on_os")
|
||||
if os_filter is not None and auto_os and os_filter.lower() != auto_os.lower():
|
||||
return []
|
||||
|
||||
try:
|
||||
session = await self._ensure_connected(name, cfg)
|
||||
if session:
|
||||
if name in self.tool_cache: return self.tool_cache[name]
|
||||
llm_tools = await self._fetch_tools_as_openai(name, session)
|
||||
self.tool_cache[name] = llm_tools
|
||||
return llm_tools
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
tasks = [ _fetch(name, cfg) for name, cfg in mcp_config.items() ]
|
||||
|
||||
if tasks:
|
||||
results = await asyncio.gather(*tasks)
|
||||
for tools in results:
|
||||
all_llm_tools.extend(tools)
|
||||
|
||||
return all_llm_tools
|
||||
|
||||
async def _ensure_connected(self, name: str, cfg: Dict[str, Any]) -> Optional[Any]:
|
||||
if not MCP_AVAILABLE: return None
|
||||
|
||||
if name in self.sessions and self.sessions[name].get("session"):
|
||||
return self.sessions[name]["session"]
|
||||
|
||||
url = cfg.get("url")
|
||||
if not url:
|
||||
return None
|
||||
|
||||
if name in self._connecting:
|
||||
try:
|
||||
return await asyncio.wait_for(asyncio.shield(self._connecting[name]), timeout=10.0)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
fut = loop.create_future()
|
||||
self._connecting[name] = fut
|
||||
|
||||
try:
|
||||
from contextlib import AsyncExitStack
|
||||
stack = AsyncExitStack()
|
||||
|
||||
async def _do_connect():
|
||||
read, write = await stack.enter_async_context(sse_client(url))
|
||||
session = await stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
return session
|
||||
|
||||
session = await asyncio.wait_for(_do_connect(), timeout=15.0)
|
||||
self.sessions[name] = {"session": session, "stack": stack}
|
||||
fut.set_result(session)
|
||||
return session
|
||||
except Exception:
|
||||
fut.set_result(None)
|
||||
return None
|
||||
finally:
|
||||
if name in self._connecting:
|
||||
del self._connecting[name]
|
||||
|
||||
async def _fetch_tools_as_openai(self, server_name: str, session: Any) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
result = await asyncio.wait_for(session.list_tools(), timeout=5.0)
|
||||
openai_tools = []
|
||||
for tool in result.tools:
|
||||
# Use mcp_ prefix to ensure valid function name for LiteLLM/Gemini
|
||||
prefixed_name = f"mcp_{server_name}__{tool.name}"
|
||||
openai_tools.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": prefixed_name,
|
||||
"description": f"[{server_name}] {tool.description}",
|
||||
"parameters": tool.inputSchema
|
||||
}
|
||||
})
|
||||
return openai_tools
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
async def call_tool(self, full_tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""Calls an MCP tool and returns text result."""
|
||||
if not MCP_AVAILABLE:
|
||||
return "Error: MCP SDK is not installed."
|
||||
|
||||
if "__" not in full_tool_name:
|
||||
return f"Error: Tool {full_tool_name} is not a valid MCP tool."
|
||||
|
||||
clean_name = full_tool_name[4:] if full_tool_name.startswith("mcp_") else full_tool_name
|
||||
server_name, tool_name = clean_name.split("__", 1)
|
||||
|
||||
if server_name not in self.sessions:
|
||||
return f"Error: MCP server {server_name} is not connected."
|
||||
|
||||
session = self.sessions[server_name]["session"]
|
||||
try:
|
||||
result = await asyncio.wait_for(session.call_tool(tool_name, arguments), timeout=60.0)
|
||||
text_outputs = [content.text for content in result.content if hasattr(content, "text")]
|
||||
return "\n".join(text_outputs) if text_outputs else str(result)
|
||||
except Exception as e:
|
||||
return f"Error calling tool {tool_name} on {server_name}: {str(e)}"
|
||||
|
||||
async def shutdown(self):
|
||||
"""Close all SSE connections."""
|
||||
for name, data in self.sessions.items():
|
||||
stack = data.get("stack")
|
||||
if stack:
|
||||
await stack.aclose()
|
||||
self.sessions = {}
|
||||
@@ -69,6 +69,7 @@ service AIService {
|
||||
rpc list_sessions (google.protobuf.Empty) returns (ValueResponse) {}
|
||||
rpc delete_session (StringRequest) returns (google.protobuf.Empty) {}
|
||||
rpc configure_provider (ProviderRequest) returns (google.protobuf.Empty) {}
|
||||
rpc configure_mcp (MCPRequest) returns (google.protobuf.Empty) {}
|
||||
rpc load_session_data (StringRequest) returns (StructResponse) {}
|
||||
}
|
||||
|
||||
@@ -282,3 +283,11 @@ message CopilotResponse {
|
||||
string risk_level = 3;
|
||||
string error = 4;
|
||||
}
|
||||
|
||||
message MCPRequest {
|
||||
string name = 1;
|
||||
string url = 2;
|
||||
bool enabled = 3;
|
||||
string auto_load_on_os = 4;
|
||||
bool remove = 5;
|
||||
}
|
||||
|
||||
@@ -60,6 +60,40 @@ class AIService(BaseService):
|
||||
self.config.config["ai"] = settings
|
||||
self.config._saveconfig(self.config.file)
|
||||
|
||||
def configure_mcp(self, name, url=None, enabled=None, auto_load_on_os=None, remove=False):
|
||||
"""Update MCP server settings in the configuration with smart merging."""
|
||||
ai_settings = self.config.config.get("ai", {})
|
||||
mcp_servers = ai_settings.get("mcp_servers", {})
|
||||
|
||||
if remove:
|
||||
if name in mcp_servers:
|
||||
del mcp_servers[name]
|
||||
else:
|
||||
# Get existing or new
|
||||
server_cfg = mcp_servers.get(name, {})
|
||||
|
||||
# Partial updates
|
||||
if url is not None:
|
||||
server_cfg["url"] = url
|
||||
|
||||
if enabled is not None:
|
||||
server_cfg["enabled"] = bool(enabled)
|
||||
elif "enabled" not in server_cfg:
|
||||
server_cfg["enabled"] = True # Default for new entries
|
||||
|
||||
if auto_load_on_os is not None:
|
||||
if auto_load_on_os == "": # Explicit clear
|
||||
if "auto_load_on_os" in server_cfg:
|
||||
del server_cfg["auto_load_on_os"]
|
||||
else:
|
||||
server_cfg["auto_load_on_os"] = auto_load_on_os
|
||||
|
||||
mcp_servers[name] = server_cfg
|
||||
|
||||
ai_settings["mcp_servers"] = mcp_servers
|
||||
self.config.config["ai"] = ai_settings
|
||||
self.config._saveconfig(self.config.file)
|
||||
|
||||
def load_session_data(self, session_id):
|
||||
"""Load a session's raw data by ID."""
|
||||
from connpy.ai import ai
|
||||
|
||||
@@ -14,7 +14,7 @@ class ExecutionService(BaseService):
|
||||
commands: List[str],
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
parallel: int = 10,
|
||||
timeout: int = 10,
|
||||
timeout: int = 20,
|
||||
folder: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
on_node_complete: Optional[Callable] = None,
|
||||
@@ -62,7 +62,7 @@ class ExecutionService(BaseService):
|
||||
expected: List[str],
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
parallel: int = 10,
|
||||
timeout: int = 10,
|
||||
timeout: int = 20,
|
||||
folder: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
on_node_complete: Optional[Callable] = None,
|
||||
@@ -139,7 +139,7 @@ class ExecutionService(BaseService):
|
||||
"commands": playbook["commands"],
|
||||
"variables": playbook.get("variables"),
|
||||
"parallel": options.get("parallel", parallel),
|
||||
"timeout": playbook.get("timeout", options.get("timeout", 10)),
|
||||
"timeout": playbook.get("timeout", options.get("timeout", 20)),
|
||||
"prompt": options.get("prompt"),
|
||||
"name": playbook.get("name", "Task")
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
@@ -11,12 +11,23 @@ class DummyConfig:
|
||||
self.config = {"ai": {"engineer_api_key": "test_key", "engineer_model": "test_model"}}
|
||||
self.defaultdir = "/tmp"
|
||||
|
||||
class MockAsyncIterator:
|
||||
def __init__(self, items):
|
||||
self.items = items
|
||||
def __aiter__(self):
|
||||
return self
|
||||
async def __anext__(self):
|
||||
if not self.items:
|
||||
raise StopAsyncIteration
|
||||
return self.items.pop(0)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_completion():
|
||||
with patch('connpy.ai.completion') as mock:
|
||||
def mock_acompletion():
|
||||
# Patch acompletion inside connpy.ai.aask_copilot
|
||||
with patch('litellm.acompletion') as mock:
|
||||
yield mock
|
||||
|
||||
def test_ask_copilot_tool_call(mock_completion):
|
||||
def test_aask_copilot_tool_call(mock_acompletion):
|
||||
agent = ai(DummyConfig())
|
||||
|
||||
# Setup mock response for streaming
|
||||
@@ -32,13 +43,20 @@ def test_ask_copilot_tool_call(mock_completion):
|
||||
def __init__(self, content):
|
||||
self.choices = [MockChoice(content)]
|
||||
|
||||
mock_completion.return_value = [
|
||||
MockChunk("<guide>Check the interfaces and running config.</guide>"),
|
||||
MockChunk("<commands>\nshow ip int br\nshow run\n</commands>"),
|
||||
MockChunk("<risk>low</risk>")
|
||||
]
|
||||
# acompletion is awaited and returns an async iterator
|
||||
async def mock_ac(*args, **kwargs):
|
||||
return MockAsyncIterator([
|
||||
MockChunk("<guide>Check the interfaces and running config.</guide>"),
|
||||
MockChunk("<commands>\nshow ip int br\nshow run\n</commands>"),
|
||||
MockChunk("<risk>low</risk>")
|
||||
])
|
||||
|
||||
result = agent.ask_copilot("Router#", "What do I do?")
|
||||
mock_acompletion.side_effect = mock_ac
|
||||
|
||||
async def run_test():
|
||||
return await agent.aask_copilot("Router#", "What do I do?")
|
||||
|
||||
result = asyncio.run(run_test())
|
||||
|
||||
if result["error"]:
|
||||
print(f"ERROR OCCURRED: {result['error']}")
|
||||
@@ -48,7 +66,7 @@ def test_ask_copilot_tool_call(mock_completion):
|
||||
assert result["risk_level"] == "low"
|
||||
assert result["commands"] == ["show ip int br", "show run"]
|
||||
|
||||
def test_ask_copilot_fallback(mock_completion):
|
||||
def test_aask_copilot_fallback(mock_acompletion):
|
||||
agent = ai(DummyConfig())
|
||||
|
||||
# Setup mock response for streaming
|
||||
@@ -64,11 +82,17 @@ def test_ask_copilot_fallback(mock_completion):
|
||||
def __init__(self, content):
|
||||
self.choices = [MockChoice(content)]
|
||||
|
||||
mock_completion.return_value = [
|
||||
MockChunk("Here is some text response instead of tool call.")
|
||||
]
|
||||
async def mock_ac(*args, **kwargs):
|
||||
return MockAsyncIterator([
|
||||
MockChunk("Here is some text response instead of tool call.")
|
||||
])
|
||||
|
||||
result = agent.ask_copilot("Router#", "What do I do?")
|
||||
mock_acompletion.side_effect = mock_ac
|
||||
|
||||
async def run_test():
|
||||
return await agent.aask_copilot("Router#", "What do I do?")
|
||||
|
||||
result = asyncio.run(run_test())
|
||||
|
||||
if result["error"]:
|
||||
print(f"ERROR OCCURRED: {result['error']}")
|
||||
|
||||
@@ -47,6 +47,24 @@ class LocalStream:
|
||||
# signal handling not supported on some loops (e.g., Windows Proactor)
|
||||
pass
|
||||
|
||||
def stop_reading(self):
|
||||
"""Temporarily stop reading from stdin."""
|
||||
if self._loop and self.stdin_fd is not None:
|
||||
try:
|
||||
self._loop.remove_reader(self.stdin_fd)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def start_reading(self):
|
||||
"""Resume reading from stdin."""
|
||||
if self._loop and self.stdin_fd is not None:
|
||||
try:
|
||||
# Ensure we don't add it twice
|
||||
self._loop.remove_reader(self.stdin_fd)
|
||||
except Exception:
|
||||
pass
|
||||
self._loop.add_reader(self.stdin_fd, self._read_ready)
|
||||
|
||||
def teardown(self):
|
||||
if self._loop:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user