From 6cabad14d010d58c1e86fd6758303de85e5fbb8a Mon Sep 17 00:00:00 2001
From: douglas-raillard-arm <douglas.raillard@gmail.com>
Date: Tue, 23 Mar 2021 11:15:35 +0000
Subject: [PATCH] connection: Make ConnectionBase.cancel() more robust

Check with poll() if the command is already finished first, to avoid
sending SIGKILL to unrelated processes due to PID recycling.

The race window still exists between the call to poll() and _cancel(),
but is reduced a great deal.
---
 devlib/connection.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/devlib/connection.py b/devlib/connection.py
index 17325db..97e72b2 100644
--- a/devlib/connection.py
+++ b/devlib/connection.py
@@ -105,12 +105,20 @@ class BackgroundCommand(ABC):
         """
         self.send_signal(signal.SIGKILL)
 
-    @abstractmethod
     def cancel(self, kill_timeout=_KILL_TIMEOUT):
         """
         Try to gracefully terminate the process by sending ``SIGTERM``, then
         waiting for ``kill_timeout`` to send ``SIGKILL``.
         """
+        if self.poll() is None:
+            self._cancel(kill_timeout=kill_timeout)
+
+    @abstractmethod
+    def _cancel(self, kill_timeout):
+        """
+        Method to override in subclasses to implement :meth:`cancel`.
+        """
+        pass
 
     @abstractmethod
     def wait(self):
@@ -209,7 +217,7 @@ class PopenBackgroundCommand(BackgroundCommand):
     def poll(self):
         return self.popen.poll()
 
-    def cancel(self, kill_timeout=_KILL_TIMEOUT):
+    def _cancel(self, kill_timeout):
         popen = self.popen
         os.killpg(os.getpgid(popen.pid), signal.SIGTERM)
         try:
@@ -266,7 +274,7 @@ class ParamikoBackgroundCommand(BackgroundCommand):
         else:
             return None
 
-    def cancel(self, kill_timeout=_KILL_TIMEOUT):
+    def _cancel(self, kill_timeout):
         self.send_signal(signal.SIGTERM)
         # Check if the command terminated quickly
         time.sleep(10e-3)
@@ -340,7 +348,7 @@ class AdbBackgroundCommand(BackgroundCommand):
     def poll(self):
         return self.adb_popen.poll()
 
-    def cancel(self, kill_timeout=_KILL_TIMEOUT):
+    def _cancel(self, kill_timeout):
         self.send_signal(signal.SIGTERM)
         try:
             self.adb_popen.wait(timeout=_KILL_TIMEOUT)