trace/
tracer.rs

1use std::{
2    fmt::Debug,
3    sync::{
4        Condvar, Mutex,
5        atomic::{AtomicU64, Ordering},
6    },
7    thread::Builder,
8    usize,
9};
10
11use miette::IntoDiagnostic;
12use monitor_api::{CompartmentFlags, CompartmentHandle};
13use twizzler::{
14    BaseType, Invariant,
15    object::{MapFlags, ObjID, Object, ObjectBuilder, RawObject, TypedObject},
16};
17use twizzler_abi::{
18    syscall::{
19        ObjectCreate, PERTHREAD_TRACE_GEN_SAMPLE, ThreadSync, ThreadSyncFlags, ThreadSyncOp,
20        ThreadSyncReference, ThreadSyncSleep, ThreadSyncWake, TraceSpec, sys_ktrace,
21        sys_thread_change_state, sys_thread_set_trace_events, sys_thread_sync,
22    },
23    thread::ExecutionState,
24    trace::{TraceBase, TraceData, TraceEntryFlags, TraceEntryHead},
25};
26
27use crate::Cli;
28
/// Lifecycle phase of a tracing session, shared between `start` and the collector thread.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Trace object created; the collector thread has not reported in yet.
    Setup,
    /// Collector thread is running and ready to drain trace data.
    Ready,
    /// Compartment threads have been resumed.
    Running,
    /// Compartment has exited; the collector should drain what is left and stop.
    Done,
}
36
37pub struct TracingState {
38    objects: Vec<Object<BaseWrap>>,
39    end_point: u64,
40    pub total: u64,
41    state: State,
42}
43
44impl Debug for TracingState {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        write!(
47            f,
48            "TracingState {{ {} objects, end_point: {}, total: {}, state: {:?} }}",
49            self.objects.len(),
50            self.end_point,
51            self.total,
52            self.state
53        )
54    }
55}
56
57#[derive(BaseType, Invariant)]
58#[repr(transparent)]
59struct BaseWrap(TraceBase);
60
61impl TracingState {
62    fn new(specs: &[TraceSpec]) -> miette::Result<Self> {
63        let prime = ObjectBuilder::new(ObjectCreate::default())
64            .build(BaseWrap(TraceBase {
65                start: 0,
66                end: AtomicU64::new(0),
67            }))
68            .into_diagnostic()?;
69
70        for spec in specs {
71            sys_ktrace(prime.id(), Some(spec)).into_diagnostic()?;
72        }
73
74        Ok(Self {
75            objects: vec![prime],
76            end_point: 0,
77            total: 0,
78            state: State::Setup,
79        })
80    }
81
82    fn collect(&mut self) -> miette::Result<ThreadSyncSleep> {
83        let mut current = self.objects.last().unwrap();
84        let posted_end = current.base().0.end.load(Ordering::SeqCst);
85        let start_point = self.end_point.max(current.base().0.start);
86        tracing::trace!(
87            "collect {:x}: {:x} {:x}: {}",
88            self.end_point,
89            posted_end,
90            start_point,
91            self.objects.len()
92        );
93        if self.end_point != posted_end {
94            let amount = posted_end.saturating_sub(start_point);
95
96            if amount > 0 {
97                self.total += amount;
98
99                // scan for next object directives
100                let mut offset = 0usize;
101                while offset < amount as usize {
102                    let header = current
103                        .lea(start_point as usize + offset, size_of::<TraceEntryHead>())
104                        .unwrap()
105                        .cast::<TraceEntryHead>();
106                    let header = unsafe { &*header };
107                    if header.flags.contains(TraceEntryFlags::NEXT_OBJECT) {
108                        tracing::debug!("got next tracing object: {}", header.extra_or_next);
109                        let next = unsafe {
110                            Object::<BaseWrap>::map_unchecked(header.extra_or_next, MapFlags::READ)
111                                .into_diagnostic()
112                        }?;
113                        self.objects.push(next);
114                        current = self.objects.last().unwrap();
115                        self.end_point = current.base().0.start;
116                        return self.collect();
117                    } else {
118                        offset += size_of::<TraceEntryHead>();
119                        if header.flags.contains(TraceEntryFlags::HAS_DATA) {
120                            let data_header = current
121                                .lea(start_point as usize + offset, size_of::<TraceData<()>>())
122                                .unwrap()
123                                .cast::<TraceData<()>>();
124                            offset += (unsafe { *data_header }).len as usize;
125                        }
126                    }
127                }
128                if offset == amount as usize {
129                    self.end_point += amount;
130                }
131            }
132        }
133
134        Ok(ThreadSyncSleep::new(
135            ThreadSyncReference::Virtual(&current.base().0.end),
136            start_point,
137            ThreadSyncOp::Equal,
138            ThreadSyncFlags::empty(),
139        ))
140    }
141
142    pub fn data(&self) -> TraceDataIter<'_> {
143        TraceDataIter {
144            state: self,
145            pos: 0,
146            inner_pos: 0,
147        }
148    }
149}
150
151pub struct TraceDataIter<'a> {
152    state: &'a TracingState,
153    pos: usize,
154    inner_pos: u64,
155}
156
157#[allow(dead_code)]
158struct Tracer {
159    state: Mutex<TracingState>,
160    specs: Vec<TraceSpec>,
161    state_cv: Condvar,
162    notifier: AtomicU64,
163}
164
165impl<'a> Iterator for TraceDataIter<'a> {
166    type Item = (&'a TraceEntryHead, Option<&'a TraceData<()>>);
167
168    fn next(&mut self) -> Option<Self::Item> {
169        let obj = self.state.objects.get(self.pos)?;
170        let start_pos = self.inner_pos.max(obj.base().0.start);
171        self.inner_pos = start_pos;
172        let end = obj.base().0.end.load(Ordering::SeqCst);
173        if start_pos + size_of::<TraceEntryHead>() as u64 > end {
174            self.pos += 1;
175            self.inner_pos = 0;
176            return self.next();
177        }
178        let mut len = size_of::<TraceEntryHead>();
179        let header = obj
180            .lea(start_pos as usize, len)
181            .unwrap()
182            .cast::<TraceEntryHead>();
183        let header = unsafe { header.as_ref().unwrap() };
184        let data = if header.flags.contains(TraceEntryFlags::HAS_DATA) {
185            let data_header = obj
186                .lea(
187                    start_pos as usize + size_of::<TraceEntryHead>(),
188                    size_of::<TraceData<()>>(),
189                )
190                .unwrap()
191                .cast::<TraceData<()>>();
192            let data_header = unsafe { data_header.as_ref().unwrap() };
193            let data = obj
194                .lea(
195                    start_pos as usize + size_of::<TraceEntryHead>(),
196                    data_header.len as usize,
197                )
198                .unwrap()
199                .cast::<TraceData<()>>();
200            let data = unsafe { data.as_ref().unwrap() };
201            len += data.len as usize;
202            Some(data)
203        } else {
204            None
205        };
206
207        self.inner_pos += len as u64;
208
209        Some((header, data))
210    }
211}
212
213impl Tracer {
214    fn set_state(&self, new_state: State) {
215        tracing::trace!("setting tracing state: {:?}", new_state);
216        let mut guard = self.state.lock().unwrap();
217        guard.state = new_state;
218        self.state_cv.notify_all();
219    }
220
221    fn wait_for(&self, target_state: State) {
222        tracing::trace!("wait for tracing state: {:?}", target_state);
223        let mut guard = self.state.lock().unwrap();
224        while guard.state != target_state {
225            guard = self.state_cv.wait(guard).unwrap();
226        }
227    }
228
229    fn notify_exit(&self) {
230        let wake = ThreadSyncWake::new(ThreadSyncReference::Virtual(&self.notifier), usize::MAX);
231        self.notifier.store(1, Ordering::SeqCst);
232        let _ = sys_thread_sync(&mut [ThreadSync::new_wake(wake)], None).inspect_err(|e| {
233            tracing::warn!("failed to notify exit: {}", e);
234        });
235    }
236}
237
238fn collector(tracer: &Tracer) {
239    tracer.set_state(State::Ready);
240    loop {
241        let mut guard = tracer.state.lock().unwrap();
242        let Ok(waiter) = guard.collect().inspect_err(|e| {
243            tracing::error!("failed to collect trace data: {}", e);
244        }) else {
245            continue;
246        };
247
248        if tracer.notifier.load(Ordering::SeqCst) == 0 {
249            drop(guard);
250            let mut waiters = [
251                ThreadSync::new_sleep(waiter),
252                ThreadSync::new_sleep(ThreadSyncSleep::new(
253                    ThreadSyncReference::Virtual(&tracer.notifier),
254                    0,
255                    ThreadSyncOp::Equal,
256                    ThreadSyncFlags::empty(),
257                )),
258            ];
259            tracing::trace!(
260                "collector is waiting for data: {} {}",
261                waiters[0].ready(),
262                waiters[1].ready()
263            );
264            if waiters.iter().all(|w| !w.ready()) {
265                let _ = sys_thread_sync(&mut waiters, None).inspect_err(|e| {
266                    tracing::warn!("failed to thread sync: {}", e);
267                });
268            }
269        } else {
270            tracing::trace!("collector was notified of exit");
271
272            let _ = sys_ktrace(guard.objects.first().unwrap().id(), None).inspect_err(|e| {
273                tracing::error!("failed to disable tracing: {}", e);
274            });
275            let _ = guard.collect().inspect_err(|e| {
276                tracing::error!("failed to collect trace data: {}", e);
277            });
278            if guard.state == State::Done {
279                break;
280            }
281            drop(tracer.state_cv.wait(guard).unwrap());
282        }
283    }
284}
285
286pub fn start(
287    cli: &Cli,
288    comp: CompartmentHandle,
289    specs: Vec<TraceSpec>,
290) -> miette::Result<TracingState> {
291    let tracer = Tracer {
292        state: Mutex::new(TracingState::new(specs.as_slice())?),
293        specs,
294        state_cv: Condvar::new(),
295        notifier: AtomicU64::new(0),
296    };
297    std::thread::scope(|scope| {
298        let th_collector = Builder::new()
299            .name("trace-collector".to_owned())
300            .spawn_scoped(scope, || collector(&tracer))
301            .into_diagnostic()?;
302
303        tracer.wait_for(State::Ready);
304
305        for thread in comp.threads() {
306            let id: ObjID = thread.repr_id;
307            tracing::debug!("resuming compartment thread {}", id);
308            sys_thread_change_state(id, ExecutionState::Running).into_diagnostic()?;
309            if cli.prog.sample {
310                tracing::debug!("setting per-thread sampling for {}", id);
311                sys_thread_set_trace_events(id, PERTHREAD_TRACE_GEN_SAMPLE).into_diagnostic()?;
312            }
313        }
314        tracer.set_state(State::Running);
315
316        let mut flags = comp.info().flags;
317        while !flags.contains(CompartmentFlags::EXITED) {
318            flags = comp.wait(flags);
319        }
320        tracing::debug!("compartment exited");
321
322        tracer.set_state(State::Done);
323        tracer.notify_exit();
324
325        th_collector.join().unwrap();
326
327        std::io::Result::Ok(()).into_diagnostic()
328    })?;
329    tracer.state.into_inner().into_diagnostic()
330}