diff --git a/src/drain.rs b/src/drain.rs
new file mode 100644
index 0000000..7ec6728
--- /dev/null
+++ b/src/drain.rs
@@ -0,0 +1,91 @@
+use std::collections::{HashMap, HashSet};
+
+use k8s_openapi::api::core::v1::{Node, Pod};
+use kube::Client;
+use kube::api::{Api, ListParams, WatchEvent, WatchParams};
+use rocket::futures::{StreamExt, TryStreamExt};
+use tracing::{debug, info, trace};
+
+pub(crate) async fn drain_node(
+    client: Client,
+    name: &str,
+) -> Result<(), kube::Error> {
+    let all_pods: Api<Pod> = Api::all(client.clone());
+    let filter = &format!("spec.nodeName={name}");
+    let mut node_pods: HashSet<_> = all_pods
+        .list(&ListParams::default().fields(filter))
+        .await?
+        .items
+        .into_iter()
+        .filter_map(|p| {
+            let name = p.metadata.name?;
+            let namespace = p.metadata.namespace?;
+            let owners = p.metadata.owner_references.unwrap_or_default();
+
+            if owners.iter().any(|o| o.kind == "DaemonSet") {
+                info!("Ignoring DaemonSet pod {name}");
+                None
+            } else {
+                Some((namespace, name))
+            }
+        })
+        .collect();
+    if node_pods.is_empty() {
+        debug!("No pods to evict from node {name}");
+        return Ok(());
+    }
+    let mut pods = HashMap::new();
+    for (namespace, name) in node_pods.iter() {
+        info!("Evicting pod {namespace}/{name}");
+        let api = pods
+            .entry(namespace)
+            .or_insert_with_key(|k| Api::<Pod>::namespaced(client.clone(), k));
+        // Propagate eviction errors immediately; otherwise the watch below
+        // would wait forever for a pod that was never evicted.
+        api.evict(name, &Default::default()).await?;
+    }
+    let mut stream = all_pods
+        .watch(&WatchParams::default().fields(filter), "0")
+        .await?
+        .boxed();
+    while let Some(event) = stream.try_next().await? {
+        trace!("Watch pod event: {event:?}");
+        if let WatchEvent::Deleted(pod) = event {
+            if let (Some(namespace), Some(name)) =
+                (pod.metadata.namespace, pod.metadata.name)
+            {
+                node_pods.remove(&(namespace, name));
+            }
+            let n = node_pods.len();
+            if n == 0 {
+                break;
+            }
+            debug!(
+                "Waiting for {n} more {}",
+                if n == 1 { "pod" } else { "pods" }
+            );
+        }
+    }
+    info!("Finished draining pods from {name}");
+    Ok(())
+}
+
+pub(crate) async fn cordon_node(
+    client: Client,
+    name: &str,
+) -> Result<(), kube::Error> {
+    let nodes: Api<Node> = Api::all(client);
+    info!("Cordoning node: {name}");
+    nodes.cordon(name).await?;
+    Ok(())
+}
+
+pub(crate) async fn uncordon_node(
+    client: Client,
+    name: &str,
+) -> Result<(), kube::Error> {
+    let nodes: Api<Node> = Api::all(client);
+    info!("Uncordoning node: {name}");
+    nodes.uncordon(name).await?;
+    Ok(())
+}
diff --git a/src/lib.rs b/src/lib.rs
index 9d90bc8..73e6ee3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+mod drain;
 mod lock;
 
 use rocket::Request;
diff --git a/src/lock.rs b/src/lock.rs
index 0939218..3e2ec2f 100644
--- a/src/lock.rs
+++ b/src/lock.rs
@@ -8,6 +8,8 @@ use rocket::http::Status;
 use rocket::request::{self, FromRequest, Request};
 use tracing::{error, info, trace, warn};
 
+use crate::drain;
+
 #[derive(Debug, rocket::Responder)]
 pub enum LockError {
     #[response(status = 500, content_type = "plain")]
@@ -159,6 +161,13 @@ pub async fn lock_v1(
             Err(e) => return Err(e.into()),
         }
     }
+    if let Err(e) = drain::cordon_node(client.clone(), &data.hostname).await {
+        error!("Failed to cordon node {}: {e}", data.hostname);
+    } else if let Err(e) =
+        drain::drain_node(client.clone(), &data.hostname).await
+    {
+        error!("Failed to drain node {}: {e}", data.hostname);
+    }
     Ok(format!(
         "Acquired reboot lock for group {}, host {}\n",
         data.group, data.hostname
@@ -177,5 +186,9 @@ pub async fn unlock_v1(
     })?;
     let lock_name = format!("reboot-lock-{}", data.group);
format!("reboot-lock-{}", data.group); update_lease(client.clone(), &lock_name, &data.hostname, None).await?; + if let Err(e) = drain::uncordon_node(client.clone(), &data.hostname).await + { + error!("Failed to uncordon node {}: {e}", data.hostname); + } Ok(()) } diff --git a/tests/integration/lock.rs b/tests/integration/lock.rs index 77f5ccf..9b4cd06 100644 --- a/tests/integration/lock.rs +++ b/tests/integration/lock.rs @@ -1,8 +1,9 @@ use std::sync::LazyLock; use k8s_openapi::api::coordination::v1::Lease; +use k8s_openapi::api::core::v1::{Node, Pod}; use kube::Client; -use kube::api::Api; +use kube::api::{Api, ListParams}; use rocket::async_test; use rocket::futures::FutureExt; use rocket::http::{ContentType, Header, Status}; @@ -28,6 +29,27 @@ async fn get_lease(name: &str) -> Result { leases.get(name).await } +async fn get_a_node() -> Result { + let client = Client::try_default().await?; + let nodes: Api = Api::all(client); + Ok(nodes.list(&Default::default()).await?.items.pop().unwrap()) +} + +async fn get_node_by_name(name: &str) -> Result { + let client = Client::try_default().await?; + let nodes: Api = Api::all(client); + nodes.get(name).await +} + +async fn get_pods_on_node(name: &str) -> Result, kube::Error> { + let client = Client::try_default().await?; + let pods: Api = Api::all(client); + Ok(pods + .list(&ListParams::default().fields(&format!("spec.nodeName=={name}"))) + .await? + .items) +} + #[async_test] async fn test_lock_v1_success() { super::setup(); @@ -349,3 +371,91 @@ fn test_unlock_v1_no_data() { Some("Error processing request:\nhostname: missing\n") ); } + +#[async_test] +async fn test_lock_v1_drain() { + super::setup(); + let _lock = &*LOCK.lock().await; + + delete_lease("reboot-lock-default").await; + let node = get_a_node().await.unwrap(); + let hostname = node.metadata.name.clone().unwrap(); + let client = super::async_client().await; + let response = client + .post("/api/v1/lock") + .header(Header::new("K8s-Reboot-Lock", "lock")) + .header(ContentType::Form) + .body(format!("hostname={hostname}")) + .dispatch() + .await; + let status = response.status(); + assert_eq!( + response.into_string().await, + Some(format!( + "Acquired reboot lock for group default, host {hostname}\n" + )) + ); + assert_eq!(status, Status::Ok); + let lease = get_lease("reboot-lock-default").await.unwrap(); + assert_eq!( + lease.spec.unwrap().holder_identity.as_ref(), + Some(&hostname) + ); + let node = get_node_by_name(&hostname).await.unwrap(); + assert!( + node.spec + .unwrap() + .taints + .unwrap() + .iter() + .any(|t| t.key == "node.kubernetes.io/unschedulable" + && t.effect == "NoSchedule") + ); + let pods = get_pods_on_node(&hostname).await.unwrap(); + assert_eq!( + pods.iter() + .filter(|p| { + !p.metadata + .owner_references + .clone() + .unwrap_or_default() + .iter() + .any(|o| o.kind == "DaemonSet") + }) + .count(), + 0 + ); +} + +#[async_test] +async fn test_unlock_v1_uncordon() { + super::setup(); + let _lock = &*LOCK.lock().await; + + let node = get_a_node().await.unwrap(); + let hostname = node.metadata.name.clone().unwrap(); + let client = super::async_client().await; + let response = client + .post("/api/v1/unlock") + .header(Header::new("K8s-Reboot-Lock", "lock")) + .header(ContentType::Form) + .body(format!("hostname={hostname}")) + .dispatch() + .await; + let status = response.status(); + assert_eq!(response.into_string().await, None,); + assert_eq!(status, Status::Ok); + let lease = get_lease("reboot-lock-default").await.unwrap(); + 
+    assert_eq!(lease.spec.unwrap().holder_identity, None);
+    let node = get_node_by_name(&hostname).await.unwrap();
+    assert!(
+        !node
+            .spec
+            .unwrap()
+            .taints
+            .unwrap_or_default()
+            .iter()
+            .any(|t| t.key == "node.kubernetes.io/unschedulable"
+                && t.effect == "NoSchedule")
+    );
+}
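
Usage sketch (not part of the patch): a per-node agent would acquire the reboot lock before rebooting and release it afterwards, mirroring the requests made by the integration tests above. The endpoint paths, the "K8s-Reboot-Lock" header, and the "hostname" form field come from those tests; the base URL, the node name, and the use of the reqwest and tokio crates are assumptions for illustration only.

// Hypothetical service address; the real deployment URL is not part of this change.
const BASE_URL: &str = "http://reboot-lock.example:8000";

// Acquire the reboot lock for this host. On success the server has also
// cordoned and drained the node, so it is safe to reboot.
async fn acquire_lock(hostname: &str) -> Result<(), reqwest::Error> {
    reqwest::Client::new()
        .post(format!("{BASE_URL}/api/v1/lock"))
        .header("K8s-Reboot-Lock", "lock")
        .form(&[("hostname", hostname)])
        .send()
        .await?
        // Treat any non-success status as "lock not acquired"; retry later.
        .error_for_status()?;
    Ok(())
}

// Release the lock after the reboot; the server uncordons the node again.
async fn release_lock(hostname: &str) -> Result<(), reqwest::Error> {
    reqwest::Client::new()
        .post(format!("{BASE_URL}/api/v1/unlock"))
        .header("K8s-Reboot-Lock", "lock")
        .form(&[("hostname", hostname)])
        .send()
        .await?
        .error_for_status()?;
    Ok(())
}

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let hostname = "node-1"; // hypothetical node name
    acquire_lock(hostname).await?;
    // ... the actual reboot would happen here ...
    release_lock(hostname).await?;
    Ok(())
}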