twizzler_io/
pty.rs

1use std::{
2    cell::UnsafeCell,
3    io::{ErrorKind, Read, Write},
4    sync::{
5        Arc, Mutex,
6        atomic::{AtomicU64, Ordering},
7    },
8};
9
10use libc::{
11    _POSIX_VDISABLE, B9600, BRKINT, CREAD, CS7, CS8, ECHO, ECHOCTL, ECHOE, ECHOK, ECHOKE, ECHONL,
12    HUPCL, ICANON, ICRNL, IEXTEN, IGNCR, IMAXBEL, INLCR, ISIG, ISTRIP, IXANY, IXON, OCRNL, ONLCR,
13    OPOST, PARENB, VEOF, VERASE, VINTR, VKILL, VQUIT, VSTATUS, VWERASE, XTABS,
14};
15use memchr::{memchr2, memchr3, memrchr, memrchr3};
16use twizzler::{
17    BaseType, Invariant,
18    object::{MapFlags, ObjID, Object, ObjectBuilder, TypedObject},
19};
20use twizzler_abi::syscall::{
21    ObjectCreate, ThreadSync, ThreadSyncFlags, ThreadSyncOp, ThreadSyncReference, ThreadSyncSleep,
22    ThreadSyncWake, sys_thread_sync,
23};
24
25use crate::buffer::VolatileBuffer;
26
/// Capacity parameter used for the PTY's `VolatileBuffer`s and for the echo line history.
pub const BUF_SZ: usize = 8 * 1024;
28
29fn do_sleep(sync: ThreadSyncSleep) -> std::io::Result<()> {
30    sys_thread_sync(&mut [ThreadSync::new_sleep(sync)], None)?;
31    Ok(())
32}
33
34#[derive(Clone)]
35struct PtyInputReader {
36    pty: Object<PtyBase>,
37}
38
39impl Read for PtyInputReader {
40    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
41        let count = self.pty.base().client_input.read_bytes(buf)?;
42        if count == 0 && buf.len() > 0 {
43            return Err(ErrorKind::WouldBlock.into());
44        }
45        Ok(count)
46    }
47}
48
49#[derive(Clone)]
50struct PtyOutputWriter {
51    pty: Object<PtyBase>,
52}
53
54impl Write for PtyOutputWriter {
55    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
56        let count = self.pty.base().client_output.write_bytes(buf)?;
57        Ok(count)
58    }
59
60    fn flush(&mut self) -> std::io::Result<()> {
61        Ok(())
62    }
63}
64
65#[derive(Clone)]
66struct PtyOutputReader {
67    pty: Object<PtyBase>,
68}
69
70impl PtyOutputReader {
71    fn read(&self, buf: &mut [u8]) -> std::io::Result<usize> {
72        let count = self.pty.base().client_output.read_bytes(buf)?;
73        Ok(count)
74    }
75}
76
77pub struct PtyClientHandle {
78    input: Arc<Mutex<InputConverter<PtyInputReader>>>,
79    output: Arc<Mutex<OutputConverter<PtyOutputWriter>>>,
80    termios_gen: AtomicU64,
81    pty: Object<PtyBase>,
82}
83
84impl Clone for PtyClientHandle {
85    fn clone(&self) -> Self {
86        Self {
87            input: self.input.clone(),
88            output: self.output.clone(),
89            termios_gen: AtomicU64::new(self.termios_gen.load(Ordering::SeqCst)),
90            pty: self.pty.clone(),
91        }
92    }
93}
94
95impl PtyClientHandle {
96    pub fn new(id: ObjID) -> std::io::Result<Self> {
97        let obj =
98            unsafe { Object::<PtyBase>::map_unchecked(id, MapFlags::READ | MapFlags::WRITE) }?;
99        let (termios, termios_gen) = obj.base().read_termios();
100        Ok(Self {
101            input: Arc::new(Mutex::new(InputConverter::new(
102                termios,
103                PtyInputReader { pty: obj.clone() },
104            ))),
105            output: Arc::new(Mutex::new(OutputConverter::new(
106                termios,
107                PtyOutputWriter { pty: obj.clone() },
108            ))),
109            termios_gen: AtomicU64::new(termios_gen),
110            pty: obj,
111        })
112    }
113
114    fn update_termios(&self) {
115        if let Some((termios, termios_gen)) = self
116            .pty
117            .base()
118            .try_read_termios(self.termios_gen.load(Ordering::SeqCst))
119        {
120            self.input.lock().unwrap().termios = termios;
121            self.output.lock().unwrap().termios = termios;
122            self.termios_gen.store(termios_gen, Ordering::SeqCst);
123        }
124    }
125
126    pub fn set_termios(&self, termios: libc::termios) {
127        self.pty.base().update_termios(|_| termios);
128    }
129
130    pub fn set_winsize(&self, winsize: libc::winsize) {
131        self.pty.base().update_winsize(|_| winsize);
132    }
133}
134
135#[derive(Clone)]
136struct PtyInputPoster {
137    pty: Object<PtyBase>,
138}
139
140impl Write for PtyInputPoster {
141    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
142        let count = self.pty.base().client_input.write_bytes(buf)?;
143        Ok(count)
144    }
145
146    fn flush(&mut self) -> std::io::Result<()> {
147        Ok(())
148    }
149}
150
151pub struct PtyServerHandle {
152    client_input: Arc<Mutex<InputPoster<PtyInputPoster, PtyOutputWriter>>>,
153    client_output: PtyOutputReader,
154    termios_gen: AtomicU64,
155    signal_handler: Option<fn(&PtyServerHandle, PtySignal)>,
156}
157
158impl Write for PtyServerHandle {
159    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
160        self.write_b(buf)
161    }
162
163    fn flush(&mut self) -> std::io::Result<()> {
164        self.flush_b()
165    }
166}
167
168impl Read for PtyServerHandle {
169    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
170        self.read_b(buf)
171    }
172}
173
174impl Write for PtyClientHandle {
175    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
176        self.write_b(buf)
177    }
178
179    fn flush(&mut self) -> std::io::Result<()> {
180        self.flush_b()
181    }
182}
183
184impl Read for PtyClientHandle {
185    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
186        self.read_b(buf)
187    }
188}
189
190impl Clone for PtyServerHandle {
191    fn clone(&self) -> Self {
192        Self {
193            client_input: self.client_input.clone(),
194            client_output: self.client_output.clone(),
195            termios_gen: AtomicU64::new(self.termios_gen.load(Ordering::SeqCst)),
196            signal_handler: self.signal_handler,
197        }
198    }
199}
200
201impl PtyServerHandle {
202    pub fn new(
203        id: ObjID,
204        signal_handler: Option<fn(&PtyServerHandle, PtySignal)>,
205    ) -> std::io::Result<Self> {
206        let obj =
207            unsafe { Object::<PtyBase>::map_unchecked(id, MapFlags::READ | MapFlags::WRITE) }?;
208        let (termios, termios_gen) = obj.base().read_termios();
209        Ok(Self {
210            client_input: Arc::new(Mutex::new(InputPoster::new(
211                termios,
212                PtyInputPoster { pty: obj.clone() },
213                PtyOutputWriter { pty: obj.clone() },
214            ))),
215            termios_gen: AtomicU64::new(termios_gen),
216            client_output: PtyOutputReader { pty: obj },
217            signal_handler,
218        })
219    }
220
221    pub fn object(&self) -> &Object<PtyBase> {
222        &self.client_output.pty
223    }
224
225    fn update_termios(&self) {
226        if let Some((termios, termios_gen)) = self
227            .client_output
228            .pty
229            .base()
230            .try_read_termios(self.termios_gen.load(Ordering::SeqCst))
231        {
232            self.client_input.lock().unwrap().termios = termios;
233            self.termios_gen.store(termios_gen, Ordering::SeqCst);
234        }
235    }
236
237    pub fn set_termios(&self, termios: libc::termios) {
238        self.client_output.pty.base().update_termios(|_| termios);
239    }
240
241    pub fn set_winsize(&self, winsize: libc::winsize) {
242        let old = self.client_output.pty.base().read_winsize().0;
243        if old.ws_row != winsize.ws_row || old.ws_col != winsize.ws_col || old.ws_xpixel != winsize.ws_xpixel || old.ws_ypixel != winsize.ws_ypixel {
244            self.client_output.pty.base().update_winsize(|_| winsize);
245            if let Some(signal_handler) = self.signal_handler {
246                (signal_handler)(self, PtySignal::Winch);
247            }
248        }
249    }
250
251    pub fn waitpoint(&self, write: bool) -> ThreadSyncSleep {
252        if write {
253            self.client_output
254                .pty
255                .base()
256                .client_input
257                .sync_for_avail_space()
258        } else {
259            self.client_output
260                .pty
261                .base()
262                .client_output
263                .sync_for_pending_data()
264        }
265    }
266
267    pub fn is_ready(&self, write: bool) -> bool {
268        if write {
269            self.client_output.pty.base().client_input.avail_space() > 0
270        } else {
271            !self.client_output.pty.base().client_output.is_empty()
272        }
273    }
274}
275
276impl PtyServerHandle {
277    pub fn write_nb(&mut self, buf: &[u8]) -> std::io::Result<usize> {
278        self.update_termios();
279        let report = self.client_input.lock().unwrap().write_input(buf)?;
280        if let Some(signal) = report.posted_signal
281            && let Some(signal_handler) = self.signal_handler
282        {
283            (signal_handler)(self, signal);
284        }
285        if report.consumed == 0 && buf.len() > 0 {
286            return Err(ErrorKind::WouldBlock.into());
287        }
288        Ok(report.consumed)
289    }
290
291    pub fn read_nb(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
292        self.update_termios();
293        let count = self.client_output.read(buf)?;
294        if count == 0 && buf.len() > 0 {
295            return Err(ErrorKind::WouldBlock.into());
296        }
297        Ok(count)
298    }
299}
300
301impl PtyServerHandle {
302    pub fn get_termios(&self) -> libc::termios {
303        self.client_output.pty.base().read_termios().0
304    }
305
306    pub fn get_winsize(&self) -> libc::winsize {
307        self.client_output.pty.base().read_winsize().0
308    }
309
310    pub fn write_b(&self, buf: &[u8]) -> std::io::Result<usize> {
311        self.update_termios();
312        let sync = self
313            .client_output
314            .pty
315            .base()
316            .client_input
317            .sync_for_avail_space();
318        let report = self.client_input.lock().unwrap().write_input(buf)?;
319        if let Some(signal) = report.posted_signal
320            && let Some(signal_handler) = self.signal_handler
321        {
322            (signal_handler)(self, signal);
323        }
324        if report.consumed == 0 && buf.len() > 0 {
325            if !self.is_ready(true) {
326                do_sleep(sync)?;
327            }
328            return self.write_b(buf);
329        }
330        Ok(report.consumed)
331    }
332
333    pub fn flush_b(&mut self) -> std::io::Result<()> {
334        Ok(())
335    }
336}
337
338impl PtyServerHandle {
339    pub fn read_b(&self, buf: &mut [u8]) -> std::io::Result<usize> {
340        self.update_termios();
341        let sync = self
342            .client_output
343            .pty
344            .base()
345            .client_output
346            .sync_for_pending_data();
347        let count = self.client_output.read(buf)?;
348        if count == 0 && buf.len() > 0 {
349            if !self.is_ready(false) {
350                do_sleep(sync)?;
351            }
352            return self.read_b(buf);
353        }
354        Ok(count)
355    }
356}
357
358impl PtyClientHandle {
359    pub fn waitpoint(&self, write: bool) -> ThreadSyncSleep {
360        if write {
361            self.pty.base().client_output.sync_for_avail_space()
362        } else {
363            self.pty.base().client_input.sync_for_pending_data()
364        }
365    }
366
367    pub fn is_ready(&self, write: bool) -> bool {
368        if write {
369            self.pty.base().client_output.avail_space() > 0
370        } else {
371            !self.pty.base().client_input.is_empty()
372        }
373    }
374
375    pub fn write_nb(&mut self, buf: &[u8]) -> std::io::Result<usize> {
376        self.update_termios();
377        let count = self.output.lock().unwrap().write(buf)?;
378        if count == 0 && buf.len() > 0 {
379            return Err(ErrorKind::WouldBlock.into());
380        }
381        Ok(count)
382    }
383
384    pub fn read_nb(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
385        self.update_termios();
386        let res = self.input.lock().unwrap().read(buf);
387        match res {
388            Ok(c) => Ok(c),
389            Err(e) if e.kind() != ErrorKind::WouldBlock => Err(e),
390            _ => {
391                if buf.len() == 0 {
392                    return Ok(0);
393                }
394                Err(ErrorKind::WouldBlock.into())
395            }
396        }
397    }
398}
399
400impl PtyClientHandle {
401    pub fn object(&self) -> Object<PtyBase> {
402        self.output.lock().unwrap().writer.pty.clone()
403    }
404
405    pub fn write_b(&self, buf: &[u8]) -> std::io::Result<usize> {
406        self.update_termios();
407        let sync = self.pty.base().client_output.sync_for_avail_space();
408        let count = self.output.lock().unwrap().write(buf)?;
409        if count == 0 && buf.len() > 0 {
410            if !self.is_ready(true) {
411                do_sleep(sync)?;
412            }
413            return self.write_b(buf);
414        }
415        Ok(count)
416    }
417
418    pub fn flush_b(&mut self) -> std::io::Result<()> {
419        self.update_termios();
420        self.output.lock().unwrap().flush()
421    }
422}
423
424impl PtyClientHandle {
425    pub fn get_termios(&self) -> libc::termios {
426        self.pty.base().read_termios().0
427    }
428
429    pub fn get_winsize(&self) -> libc::winsize {
430        self.pty.base().read_winsize().0
431    }
432
433    pub fn read_b(&self, buf: &mut [u8]) -> std::io::Result<usize> {
434        self.update_termios();
435        let sync = self.pty.base().client_input.sync_for_pending_data();
436        let res = self.input.lock().unwrap().read(buf);
437        match res {
438            Ok(c) => Ok(c),
439            Err(e) if e.kind() != ErrorKind::WouldBlock => Err(e),
440            _ => {
441                if buf.len() == 0 {
442                    return Ok(0);
443                }
444                if !self.is_ready(false) {
445                    do_sleep(sync)?;
446                }
447                self.read_b(buf)
448            }
449        }
450    }
451}
452
453#[derive(Invariant, BaseType)]
454pub struct PtyBase {
455    termios_gen: AtomicU64,
456    termios: UnsafeCell<libc::termios>,
457    winsize_gen: AtomicU64,
458    winsize: UnsafeCell<libc::winsize>,
459    client_input: VolatileBuffer<BUF_SZ>,
460    client_output: VolatileBuffer<BUF_SZ>,
461}
462
463unsafe impl Send for PtyBase {}
464unsafe impl Sync for PtyBase {}
465
/// Control-key code for an ASCII character: keeps only the low five bits, so
/// `ctrl(b'c') == ctrl(b'C') == 3`.
const fn ctrl(x: u8) -> u8 {
    x & 0x1f
}
469
470const CEOF: u8 = ctrl(b'd');
471const CEOL: u8 = _POSIX_VDISABLE;
472const CERASE: u8 = 127;
473const CINTR: u8 = ctrl(b'c');
474const CSTATUS: u8 = ctrl(b't');
475const CKILL: u8 = ctrl(b'u');
476const CMIN: u8 = 1;
477const CQUIT: u8 = 0o034; // FS, ^\
478const CSUSP: u8 = ctrl(b'z');
479const CTIME: u8 = 0;
480const _CDSUSP: u8 = ctrl(b'y');
481const CSTART: u8 = ctrl(b'q');
482const CSTOP: u8 = ctrl(b's');
483const CLNEXT: u8 = ctrl(b'v');
484const CDISCARD: u8 = ctrl(b'o');
485const CWERASE: u8 = ctrl(b'w');
486const CREPRINT: u8 = ctrl(b'r');
487const _CEOT: u8 = CEOF;
488const _CBRK: u8 = CEOL;
489const _CRPRNT: u8 = CREPRINT;
490const _CFLUSH: u8 = CDISCARD;
491
492pub const DEFAULT_TERMIOS: libc::termios = libc::termios {
493    c_iflag: BRKINT | ISTRIP | ICRNL | IMAXBEL | IXON | IXANY,
494    c_oflag: OPOST | ONLCR | XTABS,
495    c_cflag: CREAD | CS7 | PARENB | HUPCL,
496    c_lflag: ECHO | ICANON | ISIG | IEXTEN | ECHOE | ECHOKE | ECHOCTL,
497    c_cc: [
498        CINTR,
499        CQUIT,
500        CERASE,
501        CKILL,
502        CEOF,
503        CTIME,
504        CMIN,
505        _POSIX_VDISABLE,
506        CSTART,
507        CSTOP,
508        CSUSP,
509        CEOL,
510        CREPRINT,
511        CDISCARD,
512        CWERASE,
513        CLNEXT,
514        _POSIX_VDISABLE,
515        _POSIX_VDISABLE,
516        CSTATUS,
517        _POSIX_VDISABLE,
518        _POSIX_VDISABLE,
519        _POSIX_VDISABLE,
520        _POSIX_VDISABLE,
521        _POSIX_VDISABLE,
522        _POSIX_VDISABLE,
523        _POSIX_VDISABLE,
524        _POSIX_VDISABLE,
525        _POSIX_VDISABLE,
526        _POSIX_VDISABLE,
527        _POSIX_VDISABLE,
528        _POSIX_VDISABLE,
529        _POSIX_VDISABLE,
530    ],
531    __c_ispeed: B9600,
532    __c_ospeed: B9600,
533    c_line: 0,
534};
535
536pub const DEFAULT_TERMIOS_RAW: libc::termios = libc::termios {
537    c_iflag: 0,
538    c_oflag: 0,
539    c_cflag: CREAD | CS8,
540    c_lflag: 0,
541    c_cc: [
542        CINTR,
543        CQUIT,
544        CERASE,
545        CKILL,
546        CEOF,
547        CTIME,
548        CMIN,
549        _POSIX_VDISABLE,
550        CSTART,
551        CSTOP,
552        CSUSP,
553        CEOL,
554        CREPRINT,
555        CDISCARD,
556        CWERASE,
557        CLNEXT,
558        _POSIX_VDISABLE,
559        _POSIX_VDISABLE,
560        CSTATUS,
561        _POSIX_VDISABLE,
562        _POSIX_VDISABLE,
563        _POSIX_VDISABLE,
564        _POSIX_VDISABLE,
565        _POSIX_VDISABLE,
566        _POSIX_VDISABLE,
567        _POSIX_VDISABLE,
568        _POSIX_VDISABLE,
569        _POSIX_VDISABLE,
570        _POSIX_VDISABLE,
571        _POSIX_VDISABLE,
572        _POSIX_VDISABLE,
573        _POSIX_VDISABLE,
574    ],
575    __c_ispeed: B9600,
576    __c_ospeed: B9600,
577    c_line: 0,
578};
579
580impl PtyBase {
581    pub fn new(termios: libc::termios) -> Self {
582        Self {
583            termios_gen: AtomicU64::new(0),
584            termios: UnsafeCell::new(termios),
585            winsize_gen: AtomicU64::new(0),
586            winsize: UnsafeCell::new(libc::winsize { ws_row: 0, ws_col: 0, ws_xpixel: 0, ws_ypixel: 0 }),
587            client_input: VolatileBuffer::new(),
588            client_output: VolatileBuffer::new(),
589        }
590    }
591
592    pub fn create_object(
593        spec: ObjectCreate,
594        termios: libc::termios,
595    ) -> std::io::Result<Object<Self>> {
596        let obj = ObjectBuilder::new(spec).build(PtyBase::new(termios))?;
597        Ok(obj)
598    }
599
600    pub fn update_termios(
601        &self,
602        mut f: impl FnMut(libc::termios) -> libc::termios,
603    ) -> libc::termios {
604        loop {
605            let current_gen = self.termios_gen.load(std::sync::atomic::Ordering::Acquire);
606
607            // If someone else has the write lock, wait and retry.
608            if current_gen & 1 != 0 {
609                self.do_sleep_for_termios_gen(current_gen);
610                continue;
611            }
612            if self
613                .termios_gen
614                .compare_exchange(
615                    current_gen,
616                    current_gen + 1,
617                    Ordering::SeqCst,
618                    Ordering::SeqCst,
619                )
620                .is_ok()
621            {
622                // We now have the write lock.
623                let termios = unsafe { self.termios.get().read() };
624                let new_termios = f(termios);
625                unsafe { self.termios.get().write(new_termios) };
626                self.termios_gen
627                    .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
628                self.do_wake_for_termios_gen();
629                return new_termios;
630            }
631        }
632    }
633
634    fn do_wake_for_termios_gen(&self) {
635        let _ = twizzler_abi::syscall::sys_thread_sync(
636            &mut [ThreadSync::new_wake(ThreadSyncWake::new(
637                ThreadSyncReference::Virtual(&self.termios_gen),
638                usize::MAX,
639            ))],
640            None,
641        )
642        .inspect_err(|e| tracing::error!("failed to wake on termios for pty: {}", e));
643    }
644
645    fn do_sleep_for_termios_gen(&self, generation: u64) {
646        let _ = twizzler_abi::syscall::sys_thread_sync(
647            &mut [ThreadSync::new_sleep(ThreadSyncSleep::new(
648                ThreadSyncReference::Virtual(&self.termios_gen),
649                generation,
650                ThreadSyncOp::Equal,
651                ThreadSyncFlags::empty(),
652            ))],
653            None,
654        )
655        .inspect_err(|e| tracing::error!("failed to wait on termios for pty: {}", e));
656    }
657
658    pub fn try_read_termios(&self, current: u64) -> Option<(libc::termios, u64)> {
659        let current_gen = self.termios_gen.load(std::sync::atomic::Ordering::Acquire);
660        if current == current_gen {
661            return None;
662        }
663        let val = unsafe { self.termios.get().read() };
664        let after_gen = self.termios_gen.load(std::sync::atomic::Ordering::SeqCst);
665
666        if current_gen == after_gen && current_gen & 1 == 0 {
667            return Some((val, current_gen));
668        }
669        None
670    }
671
672    pub fn read_termios(&self) -> (libc::termios, u64) {
673        loop {
674            let current_gen = self.termios_gen.load(std::sync::atomic::Ordering::Acquire);
675            let val = unsafe { self.termios.get().read() };
676            let after_gen = self.termios_gen.load(std::sync::atomic::Ordering::SeqCst);
677
678            if current_gen == after_gen && current_gen & 1 == 0 {
679                return (val, current_gen);
680            }
681            self.do_sleep_for_termios_gen(after_gen);
682        }
683    }
684
685    pub fn wait_termios(&self, generation: u64) -> u64 {
686        let g = self.termios_gen.load(std::sync::atomic::Ordering::SeqCst);
687        if g != generation {
688            return g;
689        }
690        self.do_sleep_for_termios_gen(generation);
691        self.termios_gen.load(std::sync::atomic::Ordering::SeqCst)
692    }
693
694    pub fn update_winsize(
695        &self,
696        mut f: impl FnMut(libc::winsize) -> libc::winsize,
697    ) -> libc::winsize {
698        loop {
699            let current_gen = self.winsize_gen.load(std::sync::atomic::Ordering::Acquire);
700
701            if current_gen & 1 != 0 {
702                self.do_sleep_for_winsize_gen(current_gen);
703                continue;
704            }
705            if self
706                .winsize_gen
707                .compare_exchange(
708                    current_gen,
709                    current_gen + 1,
710                    Ordering::SeqCst,
711                    Ordering::SeqCst,
712                )
713                .is_ok()
714            {
715                let winsize = unsafe { self.winsize.get().read() };
716                let new_winsize = f(winsize);
717                unsafe { self.winsize.get().write(new_winsize) };
718                self.winsize_gen
719                    .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
720                self.do_wake_for_winsize_gen();
721                return new_winsize;
722            }
723        }
724    }
725
726    fn do_wake_for_winsize_gen(&self) {
727        let _ = twizzler_abi::syscall::sys_thread_sync(
728            &mut [ThreadSync::new_wake(ThreadSyncWake::new(
729                ThreadSyncReference::Virtual(&self.winsize_gen),
730                usize::MAX,
731            ))],
732            None,
733        )
734        .inspect_err(|e| tracing::error!("failed to wake on winsize for pty: {}", e));
735    }
736
737    fn do_sleep_for_winsize_gen(&self, generation: u64) {
738        let _ = twizzler_abi::syscall::sys_thread_sync(
739            &mut [ThreadSync::new_sleep(ThreadSyncSleep::new(
740                ThreadSyncReference::Virtual(&self.winsize_gen),
741                generation,
742                ThreadSyncOp::Equal,
743                ThreadSyncFlags::empty(),
744            ))],
745            None,
746        )
747        .inspect_err(|e| tracing::error!("failed to wait on winsize for pty: {}", e));
748    }
749
750    pub fn try_read_winsize(&self, current: u64) -> Option<(libc::winsize, u64)> {
751        let current_gen = self.winsize_gen.load(std::sync::atomic::Ordering::Acquire);
752        if current == current_gen {
753            return None;
754        }
755        let val = unsafe { self.winsize.get().read() };
756        let after_gen = self.winsize_gen.load(std::sync::atomic::Ordering::SeqCst);
757
758        if current_gen == after_gen && current_gen & 1 == 0 {
759            return Some((val, current_gen));
760        }
761        None
762    }
763
764    pub fn read_winsize(&self) -> (libc::winsize, u64) {
765        loop {
766            let current_gen = self.winsize_gen.load(std::sync::atomic::Ordering::Acquire);
767            let val = unsafe { self.winsize.get().read() };
768            let after_gen = self.winsize_gen.load(std::sync::atomic::Ordering::SeqCst);
769
770            if current_gen == after_gen && current_gen & 1 == 0 {
771                return (val, current_gen);
772            }
773            self.do_sleep_for_winsize_gen(after_gen);
774        }
775    }
776}
777
778#[derive(Clone)]
779pub struct InputPoster<W: Write, E: Write> {
780    termios: libc::termios,
781    writer: W,
782    echoer: E,
783    echobuf: [u8; BUF_SZ],
784    echobuf_len: usize,
785}
786
/// Signals a PTY server handle can pass to its signal handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PtySignal {
    /// The INTR control character was posted.
    Interrupt,
    /// The QUIT control character was posted.
    Quit,
    /// The STATUS control character was posted.
    Status,
    /// The window size changed.
    Winch,
}
794
795#[derive(Debug, Clone, Copy, PartialEq, Eq)]
796pub struct WriteReport {
797    pub consumed: usize,
798    pub posted_signal: Option<PtySignal>,
799}
800
801impl<W: Write, E: Write> InputPoster<W, E> {
802    pub fn new(termios: libc::termios, writer: W, echoer: E) -> Self {
803        Self {
804            termios,
805            writer,
806            echoer,
807            echobuf: [0; _],
808            echobuf_len: 0,
809        }
810    }
811
812    fn maybe_echo(&mut self, mut buf: &[u8]) -> std::io::Result<()> {
813        let echo = self.termios.c_lflag & ECHO != 0;
814        let echoe = self.termios.c_lflag & ECHOE != 0 && self.termios.c_lflag & ICANON != 0;
815        let echok = self.termios.c_lflag & ECHOK != 0 && self.termios.c_lflag & ICANON != 0;
816        let echonl = self.termios.c_lflag & ECHONL != 0 && self.termios.c_lflag & ICANON != 0;
817
818        if !echo && !echonl {
819            return Ok(());
820        }
821
822        if !echo {
823            self.echobuf_len = 0;
824            for _ in 0..buf.iter().filter(|p| **p == b'\n').count() {
825                self.echoer.write_all(&[b'\n'])?;
826            }
827            return Ok(());
828        }
829
830        while buf.len() > 0 {
831            // If we overrun the buffer, give up.
832            if self.echobuf_len == BUF_SZ {
833                self.echobuf_len = 0;
834            }
835
836            let thislen = (BUF_SZ - self.echobuf_len).min(buf.len());
837            self.echobuf[self.echobuf_len..(self.echobuf_len + thislen)]
838                .copy_from_slice(&buf[0..thislen]);
839
840            let mut cur_echo_off = self.echobuf_len;
841            self.echobuf_len += thislen;
842
843            while cur_echo_off < self.echobuf_len {
844                let echobuf = &self.echobuf[cur_echo_off..self.echobuf_len];
845                let erase_idx = memchr3(CERASE, CKILL, CWERASE, echobuf);
846                let nl_idx = memchr::memchr(b'\n', echobuf);
847                let min_idx = if let Some(e) = erase_idx
848                    && let Some(n) = nl_idx
849                {
850                    Some(e.min(n))
851                } else {
852                    erase_idx.or(nl_idx)
853                };
854
855                let erase_chars = |this: &mut Self, erase_start: usize, erase_char: usize| {
856                    this.echobuf.copy_within((erase_char + 1).., erase_start);
857                    this.echobuf_len = this
858                        .echobuf_len
859                        .saturating_sub((erase_char + 1) - erase_start);
860                };
861
862                let echolen = if let Some(idx) = min_idx {
863                    if idx > 0 {
864                        self.echoer.write_all(&echobuf[0..idx])?;
865                    }
866                    match echobuf[idx] {
867                        CERASE if echoe => {
868                            self.echoer.write_all(&[8, b' ', 8])?;
869                            erase_chars(
870                                self,
871                                (cur_echo_off + idx).saturating_sub(1),
872                                cur_echo_off + idx,
873                            );
874                        }
875                        CKILL if echok => {
876                            let idx = idx + cur_echo_off;
877                            let space = memrchr(b'\n', &self.echobuf[0..idx]).unwrap_or(0);
878                            for _ in 0..(idx.saturating_sub(space + 1)).max(1) {
879                                self.echoer.write_all(&[8, b' ', 8])?;
880                            }
881                            if space + 1 == idx {
882                                erase_chars(self, space, idx);
883                            } else {
884                                erase_chars(self, space + 1, idx);
885                            }
886                        }
887                        CWERASE if echoe => {
888                            let idx = idx + cur_echo_off;
889                            let space =
890                                memrchr3(b'\n', b'\t', b' ', &self.echobuf[0..idx]).unwrap_or(0);
891                            for _ in 0..(idx.saturating_sub(space + 1)).max(1) {
892                                self.echoer.write_all(&[8, b' ', 8])?;
893                            }
894                            if space + 1 == idx {
895                                erase_chars(self, space, idx);
896                            } else {
897                                erase_chars(self, space + 1, idx);
898                            }
899                        }
900                        b'\n' => {
901                            self.echoer.write_all(&[echobuf[idx]])?;
902                            self.echobuf_len = 0;
903                        }
904                        _ => {
905                            self.echoer.write_all(&[echobuf[idx]])?;
906                        }
907                    }
908                    idx + 1
909                } else {
910                    self.echoer.write_all(echobuf)?;
911                    echobuf.len()
912                };
913                cur_echo_off += echolen;
914            }
915
916            buf = &buf[thislen..];
917        }
918        Ok(())
919    }
920
921    pub fn write_input(&mut self, mut buf: &[u8]) -> std::io::Result<WriteReport> {
922        let vintr = self.termios.c_cc[VINTR];
923        let vquit = self.termios.c_cc[VQUIT];
924        let vstatus = self.termios.c_cc[VSTATUS];
925
926        let mut total = 0;
927        let mut sig = None;
928
929        while buf.len() > 0 && sig.is_none() {
930            let (count, skip) = if let Some(idx) = memchr3(vstatus, vintr, vquit, buf) {
931                match buf[idx] {
932                    c if c == vintr => sig = Some(PtySignal::Interrupt),
933                    c if c == vquit => sig = Some(PtySignal::Quit),
934                    c if c == vstatus => sig = Some(PtySignal::Status),
935                    _ => unreachable!(),
936                }
937                (idx, true)
938            } else {
939                (buf.len(), false)
940            };
941
942            let wcount = self.writer.write(&buf[0..count])?;
943            let mut ecount = 0;
944            while ecount < wcount {
945                let mut echobuf = [0; BUF_SZ];
946                let remaining = BUF_SZ.min(wcount - ecount);
947                echobuf[0..remaining].copy_from_slice(&buf[ecount..wcount]);
948                let c = input_map(&self.termios, &mut echobuf[0..remaining]);
949                self.maybe_echo(&echobuf[0..c])?;
950                ecount += c;
951            }
952
953            total += wcount;
954            buf = &buf[wcount..];
955            if skip && wcount == count {
956                total += 1;
957                buf = &buf[1..];
958            }
959        }
960
961        Ok(WriteReport {
962            consumed: total,
963            posted_signal: sig,
964        })
965    }
966}
967
968#[derive(Clone)]
969pub struct OutputConverter<W: Write> {
970    termios: libc::termios,
971    writer: W,
972}
973
974impl<W: Write> OutputConverter<W> {
975    pub fn new(termios: libc::termios, writer: W) -> Self {
976        Self { termios, writer }
977    }
978
979    pub fn write_bytes_simple(&mut self, buf: &[u8]) -> std::io::Result<usize> {
980        self.writer.write(buf)
981    }
982
983    pub fn write_bytes_processed(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
984        let cr_to_nl = self.termios.c_oflag & OCRNL != 0;
985        let nl_to_crnl = self.termios.c_oflag & ONLCR != 0;
986
987        if !cr_to_nl && !nl_to_crnl {
988            return self.write_bytes_simple(buf);
989        }
990
991        let mut total = 0;
992        while buf.len() > 0 {
993            let (count, extra) = if let Some(idx) = memchr2(b'\r', b'\n', buf) {
994                match buf[idx] {
995                    b'\r' if cr_to_nl => {
996                        if nl_to_crnl {
997                            (idx, Some(b"\r\n" as &[u8]))
998                        } else {
999                            (idx, Some(b"\n" as &[u8]))
1000                        }
1001                    }
1002                    b'\n' if nl_to_crnl => (idx, Some(b"\r\n" as &[u8])),
1003                    _ => (idx + 1, None),
1004                }
1005            } else {
1006                (buf.len(), None)
1007            };
1008            let thiswrite = self.writer.write(&buf[0..count])?;
1009            total += thiswrite;
1010            buf = &buf[thiswrite..];
1011            if let Some(extra) = extra {
1012                self.writer.write_all(extra)?;
1013                // Note: we only increment by 1 here because regardless of the extra
1014                // data we write, it came from 1 byte of the input buffer.
1015                total += 1;
1016                buf = &buf[1..];
1017            }
1018        }
1019
1020        Ok(total)
1021    }
1022}
1023
1024impl<W: Write> Write for OutputConverter<W> {
1025    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
1026        if self.termios.c_oflag & OPOST != 0 {
1027            self.write_bytes_processed(buf)
1028        } else {
1029            self.write_bytes_simple(buf)
1030        }
1031    }
1032
1033    fn flush(&mut self) -> std::io::Result<()> {
1034        self.writer.flush()
1035    }
1036}
1037
1038#[derive(Clone)]
1039pub struct InputConverter<R: Read> {
1040    termios: libc::termios,
1041    linebuf: [u8; BUF_SZ],
1042    linebuf_count: usize,
1043    reader: R,
1044}
1045
1046impl<R: Read> InputConverter<R> {
1047    pub fn new(termios: libc::termios, reader: R) -> Self {
1048        Self {
1049            termios,
1050            reader,
1051            linebuf_count: 0,
1052            linebuf: [0; BUF_SZ],
1053        }
1054    }
1055
1056    pub fn update_termios(&mut self, termios: libc::termios) {
1057        self.termios = termios;
1058    }
1059
1060    fn refill_linebuf(&mut self) -> std::io::Result<()> {
1061        let linebuf = &mut self.linebuf[self.linebuf_count..];
1062        let count = self.reader.read(linebuf)?;
1063        let count = input_map(&self.termios, &mut linebuf[..count]);
1064
1065        let verase = self.termios.c_cc[VERASE];
1066        let vwerase = self.termios.c_cc[VWERASE];
1067        let vkill = self.termios.c_cc[VKILL];
1068
1069        let count = if let Some(idx) = memchr3(verase, vwerase, vkill, &linebuf[..count]) {
1070            let idx = idx + self.linebuf_count;
1071
1072            let rev_idx = match self.linebuf[idx] {
1073                c if c == verase => {
1074                    if idx > 0 {
1075                        if self.linebuf[idx - 1] != b'\n' {
1076                            idx - 1
1077                        } else {
1078                            idx
1079                        }
1080                    } else {
1081                        0
1082                    }
1083                }
1084                c if c == vwerase => memrchr3(b'\n', b' ', b'\t', &self.linebuf[0..idx])
1085                    .map(|idx| idx + 1)
1086                    .unwrap_or(0),
1087                c if c == vkill => memrchr(b'\n', &self.linebuf[0..idx])
1088                    .map(|idx| idx + 1)
1089                    .unwrap_or(0),
1090                _ => panic!("invalid character"),
1091            };
1092
1093            self.linebuf.copy_within((idx + 1).., rev_idx);
1094            self.linebuf_count = self.linebuf_count.saturating_sub((idx - rev_idx).max(1));
1095
1096            count.saturating_sub((idx - rev_idx).max(1))
1097        } else {
1098            count
1099        };
1100
1101        self.linebuf_count += count;
1102        Ok(())
1103    }
1104
1105    fn drain_linebuf(&mut self, buf: &mut [u8]) -> (usize, bool) {
1106        let mut count = buf.len().min(self.linebuf_count);
1107        let linebuf = &self.linebuf[0..count];
1108
1109        let mut end = self.linebuf_count == BUF_SZ;
1110        let veof = self.termios.c_cc[VEOF];
1111
1112        if let Some(idx) = memchr2(b'\n', veof, linebuf) {
1113            if linebuf[idx] == b'\n' {
1114                count = idx + 1;
1115            } else if linebuf[idx] == veof {
1116                self.linebuf.copy_within((idx + 1).., idx);
1117                self.linebuf_count -= 1;
1118                count = idx;
1119            }
1120            end = true;
1121        }
1122
1123        if end {
1124            let linebuf = &self.linebuf[0..count];
1125            (&mut buf[0..count]).copy_from_slice(linebuf);
1126            self.linebuf.copy_within(count.., 0);
1127            self.linebuf_count -= count;
1128            (count, end)
1129        } else {
1130            (0, false)
1131        }
1132    }
1133
1134    pub fn read_canon(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
1135        let mut total = 0;
1136        while buf.len() > 0 {
1137            self.refill_linebuf()?;
1138            if self.linebuf_count == 0 {
1139                if total == 0 {
1140                    return Err(ErrorKind::WouldBlock.into());
1141                }
1142                return Ok(total);
1143            }
1144
1145            let (count, end) = self.drain_linebuf(buf);
1146
1147            buf = &mut buf[count..];
1148            total += count;
1149            if end {
1150                return Ok(total);
1151            }
1152        }
1153        Ok(total)
1154    }
1155
1156    pub fn pending_linebuf(&self) -> usize {
1157        self.linebuf_count
1158    }
1159
1160    pub fn read_raw(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
1161        let mut total = 0;
1162        while buf.len() > 0 {
1163            let thisread = match self.reader.read(buf) {
1164                Ok(l) => l,
1165                Err(e) if e.kind() == ErrorKind::WouldBlock => {
1166                    if total > 0 {
1167                        return Ok(total);
1168                    } else {
1169                        return Err(e);
1170                    }
1171                }
1172                Err(e) => return Err(e),
1173            };
1174
1175            if thisread == 0 {
1176                return Ok(total);
1177            }
1178
1179            // this might squash characters
1180            let thisread = input_map(&self.termios, &mut buf[0..thisread]);
1181
1182            total += thisread;
1183            buf = &mut buf[thisread..];
1184        }
1185        Ok(total)
1186    }
1187}
1188
1189fn input_map(termios: &libc::termios, mut buf: &mut [u8]) -> usize {
1190    let nl_to_cr = termios.c_iflag & INLCR != 0;
1191    let ignore_cr = termios.c_iflag & IGNCR != 0;
1192    let cr_to_nl = termios.c_iflag & ICRNL != 0;
1193
1194    let search_ln = nl_to_cr;
1195    let search_cr = ignore_cr || cr_to_nl;
1196
1197    if !search_cr && !search_ln {
1198        return buf.len();
1199    }
1200
1201    let mut total = 0;
1202    while buf.len() > 0 {
1203        let idx = if search_ln && search_cr {
1204            memchr::memchr2(b'\r', b'\n', buf)
1205        } else if search_cr {
1206            memchr::memchr(b'\r', buf)
1207        } else if search_ln {
1208            memchr::memchr(b'\n', buf)
1209        } else {
1210            unreachable!()
1211        };
1212
1213        if let Some(idx) = idx {
1214            let len = match buf[idx] {
1215                b'\r' if ignore_cr => {
1216                    buf.copy_within((idx + 1).., idx);
1217                    let newend = buf.len() - 1;
1218                    buf = &mut buf[idx..newend];
1219                    idx
1220                }
1221                b'\r' if cr_to_nl => {
1222                    buf[idx] = b'\n';
1223                    buf = &mut buf[(idx + 1)..];
1224                    idx + 1
1225                }
1226                b'\n' if nl_to_cr && ignore_cr => {
1227                    buf.copy_within((idx + 1).., idx);
1228                    let newend = buf.len() - 1;
1229                    buf = &mut buf[idx..newend];
1230                    idx
1231                }
1232                b'\n' if nl_to_cr => {
1233                    buf[idx] = b'\r';
1234                    buf = &mut buf[(idx + 1)..];
1235                    idx + 1
1236                }
1237                _ => {
1238                    panic!("unexpected character");
1239                }
1240            };
1241            total += len;
1242        } else {
1243            total += buf.len();
1244            return total;
1245        }
1246    }
1247
1248    total
1249}
1250
1251impl<R: Read> Read for InputConverter<R> {
1252    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
1253        if self.termios.c_lflag & ICANON != 0 {
1254            self.read_canon(buf)
1255        } else {
1256            self.read_raw(buf)
1257        }
1258    }
1259}
1260
1261pub mod more_tests {
1262    use std::io::{Cursor, Seek};
1263
1264    use libc::{ICANON, ICRNL, IGNCR, INLCR, OCRNL, ONLCR, VEOF, VERASE, VKILL, VWERASE, termios};
1265
1266    use crate::pty::{InputConverter, OutputConverter};
1267
1268    fn test_output_processing(oflag: u32, input: &[u8], expected: &[u8]) {
1269        let t = termios {
1270            c_iflag: 0,
1271            c_oflag: oflag,
1272            c_cflag: 0,
1273            c_lflag: 0,
1274            c_cc: [0; _],
1275            __c_ispeed: 0,
1276            __c_ospeed: 0,
1277            c_line: 0,
1278        };
1279        let buf = &mut [1u8; 1024] as &mut [u8];
1280        let mut cursor = Cursor::new(buf);
1281        let mut converter = OutputConverter::new(t, &mut cursor);
1282        let _written = converter.write_bytes_processed(&input).unwrap();
1283        let written = cursor.position() as usize;
1284        cursor.rewind().unwrap();
1285        let buf = cursor.get_ref();
1286        assert_eq!(&buf[0..written], expected);
1287    }
1288
1289    fn test_input_processing(iflag: u32, mut input: &[u8], expected: &[u8]) {
1290        let t = termios {
1291            c_iflag: iflag,
1292            c_oflag: 0,
1293            c_cflag: 0,
1294            c_lflag: 0,
1295            c_cc: [0; _],
1296            __c_ispeed: 0,
1297            __c_ospeed: 0,
1298            c_line: 0,
1299        };
1300        let mut converter = InputConverter::new(t, &mut input);
1301        let mut buf = [0u8; 1024];
1302        let read = converter.read_raw(&mut buf).unwrap();
1303        assert_eq!(&buf[0..read], expected);
1304    }
1305
1306    fn test_canon(iflag: u32, mut input: &[u8], expected: &[&[u8]]) {
1307        let mut t = termios {
1308            c_iflag: iflag,
1309            c_oflag: 0,
1310            c_cflag: 0,
1311            c_lflag: ICANON,
1312            c_cc: [0; _],
1313            __c_ispeed: 0,
1314            __c_ospeed: 0,
1315            c_line: 0,
1316        };
1317        t.c_cc[VEOF] = 0x4;
1318        t.c_cc[VERASE] = 0x8;
1319        t.c_cc[VWERASE] = 0x15;
1320        t.c_cc[VKILL] = 0x17;
1321        let mut converter = InputConverter::new(t, &mut input);
1322        for expected in expected {
1323            let mut buf = [0u8; 1024];
1324            let read = converter.read_canon(&mut buf).unwrap();
1325            assert_eq!(&buf[0..read], *expected);
1326        }
1327    }
1328
1329    pub fn test_raw_input_processing() {
1330        let input = b"start\ns\rend" as &[u8];
1331        test_input_processing(0, input, b"start\ns\rend");
1332
1333        test_input_processing(ICRNL, input, b"start\ns\nend");
1334        test_input_processing(INLCR, input, b"start\rs\rend");
1335        test_input_processing(IGNCR, input, b"start\nsend");
1336        test_input_processing(IGNCR | INLCR, input, b"startsend");
1337        test_input_processing(IGNCR | ICRNL, input, b"start\nsend");
1338
1339        let input = b"nothing" as &[u8];
1340        test_input_processing(ICRNL, input, b"nothing");
1341        test_input_processing(INLCR, input, b"nothing");
1342        test_input_processing(IGNCR, input, b"nothing");
1343        test_input_processing(IGNCR | INLCR, input, b"nothing");
1344        test_input_processing(IGNCR | ICRNL, input, b"nothing");
1345
1346        let input = b"\n\r" as &[u8];
1347        test_input_processing(ICRNL, input, b"\n\n");
1348        test_input_processing(INLCR, input, b"\r\r");
1349        test_input_processing(IGNCR, input, b"\n");
1350        test_input_processing(IGNCR | INLCR, input, b"");
1351        test_input_processing(IGNCR | ICRNL, input, b"\n");
1352    }
1353
1354    pub fn test_canon_input() {
1355        let input = b"first\nsecond\nthird" as &[u8];
1356        test_canon(0, input, &[b"first\n", b"second\n"]);
1357
1358        let input = b"first\nsecond\nthird\n" as &[u8];
1359        test_canon(0, input, &[b"first\n", b"second\n", b"third\n"]);
1360
1361        let input = b"first\x04second\n" as &[u8];
1362        test_canon(0, input, &[b"first", b"second\n"]);
1363
1364        let input = b"first" as &[u8];
1365        test_canon(0, input, &[]);
1366
1367        let input = b"\x04" as &[u8];
1368        test_canon(0, input, &[]);
1369
1370        let input = b"test words\x08S\n" as &[u8];
1371        test_canon(0, input, &[b"test wordS\n"]);
1372
1373        let input = b"test\n\x08S\n" as &[u8];
1374        test_canon(0, input, &[b"test\n", b"S\n"]);
1375
1376        let input = b"test words\x15S\n" as &[u8];
1377        test_canon(0, input, &[b"test S\n"]);
1378
1379        let input = b"test\n\x15S\n" as &[u8];
1380        test_canon(0, input, &[b"test\n", b"S\n"]);
1381
1382        let input = b"test words\x17S\n" as &[u8];
1383        test_canon(0, input, &[b"S\n"]);
1384
1385        let input = b"test\n\x17S\n" as &[u8];
1386        test_canon(0, input, &[b"test\n", b"S\n"]);
1387    }
1388
1389    pub fn test_output() {
1390        let input = b"start\ns\rend" as &[u8];
1391        test_output_processing(0, input, b"start\ns\rend");
1392
1393        test_output_processing(OCRNL, input, b"start\ns\nend");
1394        test_output_processing(ONLCR, input, b"start\r\ns\rend");
1395        test_output_processing(ONLCR | OCRNL, input, b"start\r\ns\r\nend");
1396    }
1397}