1#![allow(clippy::type_complexity)]
10
11use std::{
12 collections::{btree_map, BTreeMap},
13 sync::{
14 atomic::{AtomicU32, Ordering},
15 Arc,
16 },
17};
18
19use futures::{pin_mut, FutureExt};
20use pinnacle_api_defs::pinnacle::signal::v1::{SignalRequest, StreamControl};
21use tokio::sync::{
22 mpsc::{unbounded_channel, UnboundedSender},
23 oneshot,
24};
25use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};
26use tonic::Streaming;
27
28use crate::{
29 input::libinput::DeviceHandle, output::OutputHandle, tag::TagHandle, window::WindowHandle,
30 BlockOnTokio,
31};
32
/// Implemented by every signal type generated by the `signals!` macro.
///
/// It ties a zero-sized signal marker to the callback type that can be
/// registered for it.
pub(crate) trait Signal {
    /// The callback type stored for, and invoked by, this signal.
    type Callback;
}
36
/// Generates the signal marker types, their `SignalData` plumbing and the
/// public per-category signal enums.
///
/// For every `Name = { .. }` entry this emits:
/// - a unit struct `Name` implementing [`Signal`] with `Callback = callback_type`,
/// - an `impl SignalData<Name>` providing `add_callback`, `reset` and `connect`,
///   where `connect` opens the stream through the `client_request` method of the
///   signal client and dispatches responses through `on_response`.
///
/// For every `Enum => { .. }` group it emits a `pub enum Enum` with one variant
/// per entry, named `enum_name` and wrapping the entry's callback type.
/// Attributes written before a group or an entry are forwarded to the generated
/// items.
macro_rules! signals {
    ( $(
        $( #[$cfg_enum:meta] )* $enum:ident => {
            $(
                $( #[$cfg:meta] )* $name:ident = {
                    enum_name = $renamed:ident,
                    callback_type = $cb:ty,
                    client_request = $req:ident,
                    on_response = $on_resp:expr,
                }
            )*
        }
    )* ) => {$(
        $(
            $( #[$cfg] )*
            pub(crate) struct $name;

            impl $crate::signal::Signal for $name {
                type Callback = $cb;
            }

            impl SignalData<$name> {
                /// Registers `callback` for this signal and returns a handle that can
                /// later be used to remove it.
                ///
                /// The stream is (re)connected first when no callback is currently
                /// counted as registered.
                ///
                /// # Panics
                ///
                /// Panics if the callback cannot be handed to the signal task.
                pub(crate) fn add_callback(&mut self, callback: <$name as Signal>::Callback) -> SignalHandle {
                    let needs_connection = self.callback_count.load(Ordering::SeqCst) == 0;
                    if needs_connection {
                        self.connect();
                    }

                    // Both senders are populated by `connect`, so they must exist here.
                    let (callback_tx, remove_tx) = match (
                        self.callback_sender.as_ref(),
                        self.remove_callback_sender.as_ref(),
                    ) {
                        (Some(callback_tx), Some(remove_tx)) => (callback_tx, remove_tx.clone()),
                        _ => unreachable!("signal should already be connected here"),
                    };

                    let id = self.current_id;

                    callback_tx
                        .send((id, callback))
                        .expect("failed to send callback");

                    self.callback_count.fetch_add(1, Ordering::SeqCst);

                    // The next registration gets a fresh id.
                    self.current_id.0 += 1;

                    SignalHandle::new(id, remove_tx)
                }

                /// Drops all channel ends held for the current connection and restores
                /// the callback count and id counter to their initial values.
                fn reset(&mut self) {
                    self.callback_sender.take();
                    self.dc_pinger.take();
                    self.remove_callback_sender.take();
                    self.callback_count.store(0, Ordering::SeqCst);
                    self.current_id = SignalConnId::default();
                }

                /// Resets this signal and opens a new stream, storing the channel
                /// ends returned by `connect_signal`.
                fn connect(&mut self) {
                    self.reset();

                    let ConnectSignalChannels {
                        callback_sender,
                        dc_pinger,
                        remove_callback_sender,
                    } = connect_signal::<_, _, <$name as Signal>::Callback, _, _>(
                        self.callback_count.clone(),
                        |request_stream| {
                            $crate::client::Client::signal().$req(request_stream)
                                .block_on_tokio()
                                .expect("failed to request signal connection")
                                .into_inner()
                        },
                        $on_resp,
                    );

                    self.callback_sender = Some(callback_sender);
                    self.dc_pinger = Some(dc_pinger);
                    self.remove_callback_sender = Some(remove_callback_sender);
                }
            }
        )*

        $( #[$cfg_enum] )*
        pub enum $enum {
            $( $( #[$cfg] )* $renamed($cb),)*
        }
    )*};
}
121
signals! {
    /// Signals relating to outputs.
    OutputSignal => {
        /// Signal requested through `output_connect`.
        ///
        /// Callbacks are called with a handle to the output named in the response.
        OutputConnect = {
            enum_name = Connect,
            callback_type = SingleOutputFn,
            client_request = output_connect,
            on_response = |response, callbacks| {
                let handle = OutputHandle { name: response.output_name };

                for callback in callbacks {
                    callback(&handle);
                }
            },
        }
        /// Signal requested through `output_disconnect`.
        ///
        /// Callbacks are called with a handle to the output named in the response.
        OutputDisconnect = {
            enum_name = Disconnect,
            callback_type = SingleOutputFn,
            client_request = output_disconnect,
            on_response = |response, callbacks| {
                let handle = OutputHandle { name: response.output_name };

                for callback in callbacks {
                    callback(&handle);
                }
            },
        }
        /// Signal requested through `output_resize`.
        ///
        /// Callbacks are called with a handle to the output named in the response
        /// and the response's logical width and height.
        OutputResize = {
            enum_name = Resize,
            callback_type = Box<dyn FnMut(&OutputHandle, u32, u32) + Send + 'static>,
            client_request = output_resize,
            on_response = |response, callbacks| {
                let handle = OutputHandle { name: response.output_name };

                for callback in callbacks {
                    callback(&handle, response.logical_width, response.logical_height)
                }
            },
        }
        /// Signal requested through `output_move`.
        ///
        /// Callbacks are called with a handle to the output named in the response
        /// and the response's `x` and `y` values.
        OutputMove = {
            enum_name = Move,
            callback_type = Box<dyn FnMut(&OutputHandle, i32, i32) + Send + 'static>,
            client_request = output_move,
            on_response = |response, callbacks| {
                let handle = OutputHandle { name: response.output_name };

                for callback in callbacks {
                    callback(&handle, response.x, response.y)
                }
            },
        }
    }
    /// Signals relating to windows.
    WindowSignal => {
        /// Signal requested through `window_pointer_enter`.
        ///
        /// Callbacks are called with a handle to the window identified in the response.
        WindowPointerEnter = {
            enum_name = PointerEnter,
            callback_type = SingleWindowFn,
            client_request = window_pointer_enter,
            on_response = |response, callbacks| {
                let handle = WindowHandle { id: response.window_id };

                for callback in callbacks {
                    callback(&handle);
                }
            },
        }
        /// Signal requested through `window_pointer_leave`.
        ///
        /// Callbacks are called with a handle to the window identified in the response.
        WindowPointerLeave = {
            enum_name = PointerLeave,
            callback_type = SingleWindowFn,
            client_request = window_pointer_leave,
            on_response = |response, callbacks| {
                let handle = WindowHandle { id: response.window_id };

                for callback in callbacks {
                    callback(&handle);
                }
            },
        }
        /// Signal requested through `window_focused`.
        ///
        /// Callbacks are called with a handle to the window identified in the response.
        WindowFocused = {
            enum_name = Focused,
            callback_type = SingleWindowFn,
            client_request = window_focused,
            on_response = |response, callbacks| {
                let handle = WindowHandle { id: response.window_id };

                for callback in callbacks {
                    callback(&handle);
                }
            },
        }
    }
    /// Signals relating to tags.
    TagSignal => {
        /// Signal requested through `tag_active`.
        ///
        /// Callbacks are called with a handle to the tag identified in the response
        /// and the response's `active` flag.
        TagActive = {
            enum_name = Active,
            callback_type = Box<dyn FnMut(&TagHandle, bool) + Send + 'static>,
            client_request = tag_active,
            on_response = |response, callbacks| {
                let handle = TagHandle { id: response.tag_id };

                for callback in callbacks {
                    callback(&handle, response.active);
                }
            },
        }
    }
    /// Signals relating to input devices.
    InputSignal => {
        /// Signal requested through `input_device_added`.
        ///
        /// Callbacks are called with a handle to the device whose sysname is in the
        /// response.
        InputDeviceAdded = {
            enum_name = DeviceAdded,
            callback_type = Box<dyn FnMut(&DeviceHandle) + Send + 'static>,
            client_request = input_device_added,
            on_response = |response, callbacks| {
                let handle = DeviceHandle { sysname: response.device_sysname };

                for callback in callbacks {
                    callback(&handle);
                }
            },
        }
    }
}
270
/// Boxed callback that receives a single output handle.
pub(crate) type SingleOutputFn = Box<dyn FnMut(&OutputHandle) + Send + 'static>;
/// Boxed callback that receives a single window handle.
pub(crate) type SingleWindowFn = Box<dyn FnMut(&WindowHandle) + Send + 'static>;
273
/// Connection state for every signal, one [`SignalData`] per signal type.
pub(crate) struct SignalState {
    // Output signals.
    pub(crate) output_connect: SignalData<OutputConnect>,
    pub(crate) output_disconnect: SignalData<OutputDisconnect>,
    pub(crate) output_resize: SignalData<OutputResize>,
    pub(crate) output_move: SignalData<OutputMove>,

    // Window signals.
    pub(crate) window_pointer_enter: SignalData<WindowPointerEnter>,
    pub(crate) window_pointer_leave: SignalData<WindowPointerLeave>,
    pub(crate) window_focused: SignalData<WindowFocused>,

    // Tag signals.
    pub(crate) tag_active: SignalData<TagActive>,

    // Input signals.
    pub(crate) input_device_added: SignalData<InputDeviceAdded>,
}
288
289impl std::fmt::Debug for SignalState {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 f.debug_struct("SignalState").finish()
292 }
293}
294
295impl SignalState {
296 pub(crate) fn new() -> Self {
297 Self {
298 output_connect: SignalData::new(),
299 output_disconnect: SignalData::new(),
300 output_resize: SignalData::new(),
301 output_move: SignalData::new(),
302 window_pointer_enter: SignalData::new(),
303 window_pointer_leave: SignalData::new(),
304 window_focused: SignalData::new(),
305 tag_active: SignalData::new(),
306 input_device_added: SignalData::new(),
307 }
308 }
309
310 pub(crate) fn shutdown(&mut self) {
311 self.output_connect.reset();
312 self.output_disconnect.reset();
313 self.output_resize.reset();
314 self.output_move.reset();
315 self.window_pointer_enter.reset();
316 self.window_pointer_leave.reset();
317 self.window_focused.reset();
318 self.tag_active.reset();
319 self.input_device_added.reset();
320 }
321}
322
/// Identifier of one registered callback within a signal connection.
///
/// Ids are handed out in increasing order by `add_callback` and are used as the
/// key under which the signal task stores the callback.
#[derive(Default, Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct SignalConnId(pub(crate) u32);
325
/// Client-side bookkeeping for a single signal `S`.
///
/// The `Option` fields are `None` until the signal is connected and are
/// cleared again when it is reset.
pub(crate) struct SignalData<S: Signal> {
    /// Sends newly registered callbacks, with their id, to the signal task.
    callback_sender: Option<UnboundedSender<(SignalConnId, S::Callback)>>,
    /// Sends ids of callbacks to remove; cloned into every `SignalHandle`.
    remove_callback_sender: Option<UnboundedSender<SignalConnId>>,
    /// Paired with the receiver the signal task watches to stop itself.
    dc_pinger: Option<oneshot::Sender<()>>,
    /// Id that the next registered callback will receive.
    current_id: SignalConnId,
    /// Number of registered callbacks, shared with the signal task.
    callback_count: Arc<AtomicU32>,
}
333
334impl<S: Signal> SignalData<S> {
335 fn new() -> Self {
336 Self {
337 callback_sender: Default::default(),
338 remove_callback_sender: Default::default(),
339 dc_pinger: Default::default(),
340 current_id: Default::default(),
341 callback_count: Default::default(),
342 }
343 }
344}
345
/// The channel ends returned by `connect_signal` for talking to the spawned
/// signal task. `F` is the callback type.
struct ConnectSignalChannels<F> {
    /// Registers a callback under the given id.
    callback_sender: UnboundedSender<(SignalConnId, F)>,
    /// Makes the task send a disconnect request and stop.
    dc_pinger: oneshot::Sender<()>,
    /// Removes the callback registered under the given id.
    remove_callback_sender: UnboundedSender<SignalConnId>,
}
351
352fn connect_signal<Req, Resp, F, T, O>(
353 callback_count: Arc<AtomicU32>,
354 to_in_stream: T,
355 mut on_response: O,
356) -> ConnectSignalChannels<F>
357where
358 Req: SignalRequest + Send + 'static,
359 Resp: Send + 'static,
360 F: Send + 'static,
361 T: FnOnce(UnboundedReceiverStream<Req>) -> Streaming<Resp>,
362 O: FnMut(Resp, btree_map::ValuesMut<'_, SignalConnId, F>) + Send + 'static,
363{
364 let (control_sender, recv) = unbounded_channel::<Req>();
365 let out_stream = UnboundedReceiverStream::new(recv);
366
367 let mut in_stream = to_in_stream(out_stream);
368
369 let (callback_sender, mut callback_recv) = unbounded_channel::<(SignalConnId, F)>();
370 let (remove_callback_sender, mut remove_callback_recv) = unbounded_channel::<SignalConnId>();
371 let (dc_pinger, mut dc_ping_recv) = oneshot::channel::<()>();
372
373 let signal_future = async move {
374 let mut callbacks = BTreeMap::<SignalConnId, F>::new();
375
376 control_sender
377 .send(Req::from_control(StreamControl::Ready))
378 .map_err(|err| {
379 println!("{err}");
380 err
381 })
382 .expect("send failed");
383
384 loop {
385 let in_stream_next = in_stream.next().fuse();
386 pin_mut!(in_stream_next);
387 let callback_recv_recv = callback_recv.recv().fuse();
388 pin_mut!(callback_recv_recv);
389 let remove_callback_recv_recv = remove_callback_recv.recv().fuse();
390 pin_mut!(remove_callback_recv_recv);
391 let mut dc_ping_recv_fuse = (&mut dc_ping_recv).fuse();
392
393 futures::select! {
394 response = in_stream_next => {
395 let Some(response) = response else { continue };
396
397 match response {
398 Ok(response) => {
399 on_response(response, callbacks.values_mut());
400
401 control_sender
402 .send(Req::from_control(StreamControl::Ready))
403 .expect("send failed");
404
405 tokio::task::yield_now().await;
406 }
407 Err(status) => eprintln!("Error in recv: {status}"),
408 }
409 }
410 callback = callback_recv_recv => {
411 if let Some((id, callback)) = callback {
412 callbacks.insert(id, callback);
413 }
416 }
417 remove = remove_callback_recv_recv => {
418 if let Some(id) = remove {
419 if callbacks.remove(&id).is_some() {
420 assert!(callback_count.fetch_sub(1, Ordering::SeqCst) > 0);
421 }
422 if callbacks.is_empty() {
423 assert!(callback_count.load(Ordering::SeqCst) == 0);
424 control_sender.send(Req::from_control(StreamControl::Disconnect)).expect("send failed");
425 break;
426 }
427 }
428 }
429 _dc = dc_ping_recv_fuse => {
430 let _ = control_sender.send(Req::from_control(StreamControl::Disconnect));
431 break;
432 }
433 }
434 }
435 };
436
437 tokio::spawn(signal_future);
438
439 ConnectSignalChannels {
440 callback_sender,
441 dc_pinger,
442 remove_callback_sender,
443 }
444}
445
/// A handle to a callback registered for a signal.
///
/// Call [`SignalHandle::disconnect`] to request removal of that callback.
pub struct SignalHandle {
    /// Id under which the callback was registered.
    id: SignalConnId,
    /// Channel on which removal requests are sent to the signal task.
    remove_callback_sender: UnboundedSender<SignalConnId>,
}
453
454impl SignalHandle {
455 pub(crate) fn new(
456 id: SignalConnId,
457 remove_callback_sender: UnboundedSender<SignalConnId>,
458 ) -> Self {
459 Self {
460 id,
461 remove_callback_sender,
462 }
463 }
464
465 pub fn disconnect(self) {
467 self.remove_callback_sender
468 .send(self.id)
469 .expect("failed to disconnect signal");
470 }
471}