1use std::{
2 fmt::Debug,
3 sync::{
4 Condvar, Mutex,
5 atomic::{AtomicU64, Ordering},
6 },
7 thread::Builder,
8 usize,
9};
10
11use miette::IntoDiagnostic;
12use monitor_api::{CompartmentFlags, CompartmentHandle};
13use twizzler::{
14 BaseType, Invariant,
15 object::{MapFlags, ObjID, Object, ObjectBuilder, RawObject, TypedObject},
16};
17use twizzler_abi::{
18 syscall::{
19 ObjectCreate, PERTHREAD_TRACE_GEN_SAMPLE, ThreadSync, ThreadSyncFlags, ThreadSyncOp,
20 ThreadSyncReference, ThreadSyncSleep, ThreadSyncWake, TraceSpec, sys_ktrace,
21 sys_thread_change_state, sys_thread_set_trace_events, sys_thread_sync,
22 },
23 thread::ExecutionState,
24 trace::{TraceBase, TraceData, TraceEntryFlags, TraceEntryHead},
25};
26
27use crate::Cli;
28
/// Phase of a tracing session, shared between the driver thread and the collector thread.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Tracing state has been created; the collector has not started yet.
    Setup,
    /// The collector thread is running and waiting for data.
    Ready,
    /// The traced compartment's threads have been resumed.
    Running,
    /// The traced compartment has finished; the collector should wrap up.
    Done,
}
36
37pub struct TracingState {
38 objects: Vec<Object<BaseWrap>>,
39 end_point: u64,
40 pub total: u64,
41 state: State,
42}
43
44impl Debug for TracingState {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 write!(
47 f,
48 "TracingState {{ {} objects, end_point: {}, total: {}, state: {:?} }}",
49 self.objects.len(),
50 self.end_point,
51 self.total,
52 self.state
53 )
54 }
55}
56
57#[derive(BaseType, Invariant)]
58#[repr(transparent)]
59struct BaseWrap(TraceBase);
60
61impl TracingState {
62 fn new(specs: &[TraceSpec]) -> miette::Result<Self> {
63 let prime = ObjectBuilder::new(ObjectCreate::default())
64 .build(BaseWrap(TraceBase {
65 start: 0,
66 end: AtomicU64::new(0),
67 }))
68 .into_diagnostic()?;
69
70 for spec in specs {
71 sys_ktrace(prime.id(), Some(spec)).into_diagnostic()?;
72 }
73
74 Ok(Self {
75 objects: vec![prime],
76 end_point: 0,
77 total: 0,
78 state: State::Setup,
79 })
80 }
81
82 fn collect(&mut self) -> miette::Result<ThreadSyncSleep> {
83 let mut current = self.objects.last().unwrap();
84 let posted_end = current.base().0.end.load(Ordering::SeqCst);
85 let start_point = self.end_point.max(current.base().0.start);
86 tracing::trace!(
87 "collect {:x}: {:x} {:x}: {}",
88 self.end_point,
89 posted_end,
90 start_point,
91 self.objects.len()
92 );
93 if self.end_point != posted_end {
94 let amount = posted_end.saturating_sub(start_point);
95
96 if amount > 0 {
97 self.total += amount;
98
99 let mut offset = 0usize;
101 while offset < amount as usize {
102 let header = current
103 .lea(start_point as usize + offset, size_of::<TraceEntryHead>())
104 .unwrap()
105 .cast::<TraceEntryHead>();
106 let header = unsafe { &*header };
107 if header.flags.contains(TraceEntryFlags::NEXT_OBJECT) {
108 tracing::debug!("got next tracing object: {}", header.extra_or_next);
109 let next = unsafe {
110 Object::<BaseWrap>::map_unchecked(header.extra_or_next, MapFlags::READ)
111 .into_diagnostic()
112 }?;
113 self.objects.push(next);
114 current = self.objects.last().unwrap();
115 self.end_point = current.base().0.start;
116 return self.collect();
117 } else {
118 offset += size_of::<TraceEntryHead>();
119 if header.flags.contains(TraceEntryFlags::HAS_DATA) {
120 let data_header = current
121 .lea(start_point as usize + offset, size_of::<TraceData<()>>())
122 .unwrap()
123 .cast::<TraceData<()>>();
124 offset += (unsafe { *data_header }).len as usize;
125 }
126 }
127 }
128 if offset == amount as usize {
129 self.end_point += amount;
130 }
131 }
132 }
133
134 Ok(ThreadSyncSleep::new(
135 ThreadSyncReference::Virtual(¤t.base().0.end),
136 start_point,
137 ThreadSyncOp::Equal,
138 ThreadSyncFlags::empty(),
139 ))
140 }
141
142 pub fn data(&self) -> TraceDataIter<'_> {
143 TraceDataIter {
144 state: self,
145 pos: 0,
146 inner_pos: 0,
147 }
148 }
149}
150
151pub struct TraceDataIter<'a> {
152 state: &'a TracingState,
153 pos: usize,
154 inner_pos: u64,
155}
156
157#[allow(dead_code)]
158struct Tracer {
159 state: Mutex<TracingState>,
160 specs: Vec<TraceSpec>,
161 state_cv: Condvar,
162 notifier: AtomicU64,
163}
164
165impl<'a> Iterator for TraceDataIter<'a> {
166 type Item = (&'a TraceEntryHead, Option<&'a TraceData<()>>);
167
168 fn next(&mut self) -> Option<Self::Item> {
169 let obj = self.state.objects.get(self.pos)?;
170 let start_pos = self.inner_pos.max(obj.base().0.start);
171 self.inner_pos = start_pos;
172 let end = obj.base().0.end.load(Ordering::SeqCst);
173 if start_pos + size_of::<TraceEntryHead>() as u64 > end {
174 self.pos += 1;
175 self.inner_pos = 0;
176 return self.next();
177 }
178 let mut len = size_of::<TraceEntryHead>();
179 let header = obj
180 .lea(start_pos as usize, len)
181 .unwrap()
182 .cast::<TraceEntryHead>();
183 let header = unsafe { header.as_ref().unwrap() };
184 let data = if header.flags.contains(TraceEntryFlags::HAS_DATA) {
185 let data_header = obj
186 .lea(
187 start_pos as usize + size_of::<TraceEntryHead>(),
188 size_of::<TraceData<()>>(),
189 )
190 .unwrap()
191 .cast::<TraceData<()>>();
192 let data_header = unsafe { data_header.as_ref().unwrap() };
193 let data = obj
194 .lea(
195 start_pos as usize + size_of::<TraceEntryHead>(),
196 data_header.len as usize,
197 )
198 .unwrap()
199 .cast::<TraceData<()>>();
200 let data = unsafe { data.as_ref().unwrap() };
201 len += data.len as usize;
202 Some(data)
203 } else {
204 None
205 };
206
207 self.inner_pos += len as u64;
208
209 Some((header, data))
210 }
211}
212
213impl Tracer {
214 fn set_state(&self, new_state: State) {
215 tracing::trace!("setting tracing state: {:?}", new_state);
216 let mut guard = self.state.lock().unwrap();
217 guard.state = new_state;
218 self.state_cv.notify_all();
219 }
220
221 fn wait_for(&self, target_state: State) {
222 tracing::trace!("wait for tracing state: {:?}", target_state);
223 let mut guard = self.state.lock().unwrap();
224 while guard.state != target_state {
225 guard = self.state_cv.wait(guard).unwrap();
226 }
227 }
228
229 fn notify_exit(&self) {
230 let wake = ThreadSyncWake::new(ThreadSyncReference::Virtual(&self.notifier), usize::MAX);
231 self.notifier.store(1, Ordering::SeqCst);
232 let _ = sys_thread_sync(&mut [ThreadSync::new_wake(wake)], None).inspect_err(|e| {
233 tracing::warn!("failed to notify exit: {}", e);
234 });
235 }
236}
237
238fn collector(tracer: &Tracer) {
239 tracer.set_state(State::Ready);
240 loop {
241 let mut guard = tracer.state.lock().unwrap();
242 let Ok(waiter) = guard.collect().inspect_err(|e| {
243 tracing::error!("failed to collect trace data: {}", e);
244 }) else {
245 continue;
246 };
247
248 if tracer.notifier.load(Ordering::SeqCst) == 0 {
249 drop(guard);
250 let mut waiters = [
251 ThreadSync::new_sleep(waiter),
252 ThreadSync::new_sleep(ThreadSyncSleep::new(
253 ThreadSyncReference::Virtual(&tracer.notifier),
254 0,
255 ThreadSyncOp::Equal,
256 ThreadSyncFlags::empty(),
257 )),
258 ];
259 tracing::trace!(
260 "collector is waiting for data: {} {}",
261 waiters[0].ready(),
262 waiters[1].ready()
263 );
264 if waiters.iter().all(|w| !w.ready()) {
265 let _ = sys_thread_sync(&mut waiters, None).inspect_err(|e| {
266 tracing::warn!("failed to thread sync: {}", e);
267 });
268 }
269 } else {
270 tracing::trace!("collector was notified of exit");
271
272 let _ = sys_ktrace(guard.objects.first().unwrap().id(), None).inspect_err(|e| {
273 tracing::error!("failed to disable tracing: {}", e);
274 });
275 let _ = guard.collect().inspect_err(|e| {
276 tracing::error!("failed to collect trace data: {}", e);
277 });
278 if guard.state == State::Done {
279 break;
280 }
281 drop(tracer.state_cv.wait(guard).unwrap());
282 }
283 }
284}
285
286pub fn start(
287 cli: &Cli,
288 comp: CompartmentHandle,
289 specs: Vec<TraceSpec>,
290) -> miette::Result<TracingState> {
291 let tracer = Tracer {
292 state: Mutex::new(TracingState::new(specs.as_slice())?),
293 specs,
294 state_cv: Condvar::new(),
295 notifier: AtomicU64::new(0),
296 };
297 std::thread::scope(|scope| {
298 let th_collector = Builder::new()
299 .name("trace-collector".to_owned())
300 .spawn_scoped(scope, || collector(&tracer))
301 .into_diagnostic()?;
302
303 tracer.wait_for(State::Ready);
304
305 for thread in comp.threads() {
306 let id: ObjID = thread.repr_id;
307 tracing::debug!("resuming compartment thread {}", id);
308 sys_thread_change_state(id, ExecutionState::Running).into_diagnostic()?;
309 if cli.prog.sample {
310 tracing::debug!("setting per-thread sampling for {}", id);
311 sys_thread_set_trace_events(id, PERTHREAD_TRACE_GEN_SAMPLE).into_diagnostic()?;
312 }
313 }
314 tracer.set_state(State::Running);
315
316 let mut flags = comp.info().flags;
317 while !flags.contains(CompartmentFlags::EXITED) {
318 flags = comp.wait(flags);
319 }
320 tracing::debug!("compartment exited");
321
322 tracer.set_state(State::Done);
323 tracer.notify_exit();
324
325 th_collector.join().unwrap();
326
327 std::io::Result::Ok(()).into_diagnostic()
328 })?;
329 tracer.state.into_inner().into_diagnostic()
330}