dynlink/
tls.rs

//! Implements ELF TLS Variants I and II. I highly recommend reading the Fuchsia docs on
//! thread-local storage as preparation for this code.
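//!
//! Very roughly, the regions built by `allocate` below look like this, where `tp` is the thread
//! pointer (alignment padding omitted):
//!
//! ```text
//! Variant 1 (e.g. aarch64):  [ TCB | module 1 | module 2 | ... | DTV ]
//!                            (tp points near the end of the TCB; modules begin 16 bytes past it)
//!
//! Variant 2 (e.g. x86_64):   [ DTV | ... | module 2 | module 1 | TCB ]
//!                            (tp points at the boundary between module 1 and the TCB)
//! ```
//!
//! See `TlsInfo::insert` and `TlsInfo::allocate` for the exact offset calculations.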

use std::{
    alloc::Layout,
    mem::{align_of, size_of},
    ptr::NonNull,
};

use tracing::{error, trace};
use twizzler_rt_abi::thread::TlsIndex;

// re-export TLS TCB definition
pub use crate::arch::Tcb;
use crate::{
    arch::{get_tls_variant, MINIMUM_TLS_ALIGNMENT},
    compartment::Compartment,
    DynlinkError, DynlinkErrorKind,
};

/// The TLS variant which determines the layout of the TLS region.
pub enum TlsVariant {
    Variant1,
    Variant2,
}

#[derive(Clone)]
pub(crate) struct TlsInfo {
    // The DTV needs to start with a generation count.
    gen: u64,
    // Total size of all modules added to the static TLS region.
    alloc_size_mods: usize,
    // Maximum alignment we'll need across all modules.
    max_align: usize,
    // Running offset counter.
    offset: usize,
    pub(crate) tls_mods: Vec<TlsModule>,
}

impl TlsInfo {
    pub(crate) fn new(gen: u64) -> Self {
        Self {
            gen,
            alloc_size_mods: Default::default(),
            max_align: MINIMUM_TLS_ALIGNMENT,
            tls_mods: Default::default(),
            offset: 0,
        }
    }

    pub(crate) fn clone_to_new_gen(&self, new_gen: u64) -> Self {
        Self {
            gen: new_gen,
            alloc_size_mods: self.alloc_size_mods,
            max_align: self.max_align,
            tls_mods: self.tls_mods.clone(),
            offset: self.offset,
        }
    }
}

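/// Per-module TLS template description, typically derived from the module's PT_TLS program
/// header: `template_addr` is the address of the mapped template data, `template_filesz` and
/// `template_memsz` are its initialized and total sizes, and `template_align` is its required
/// alignment. As a purely illustrative example (hypothetical values), a PT_TLS segment mapped
/// at 0x1000 with p_filesz = 0x20, p_memsz = 0x40, and p_align = 8 would be registered via
/// `TlsModule::new_static(0x1000, 0x20, 0x40, 8)`.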
#[derive(Debug, Clone)]
pub struct TlsModule {
    pub is_static: bool,
    pub template_addr: usize,
    pub template_filesz: usize,
    pub template_memsz: usize,
    pub template_align: usize,
    pub offset: Option<usize>,
    pub id: Option<TlsModId>,
}

impl TlsModule {
    pub(crate) fn new_static(
        template_addr: usize,
        template_filesz: usize,
        template_memsz: usize,
        template_align: usize,
    ) -> Self {
        Self {
            is_static: true,
            template_addr,
            template_filesz,
            template_memsz,
            template_align,
            offset: None,
            id: None,
        }
    }
}

impl TlsInfo {
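    /// Add a TLS module to the static TLS layout, assigning it an offset from the thread
    /// pointer and a module ID (which doubles as its DTV slot). As a purely illustrative,
    /// worked example for Variant 2 (hypothetical values): a first module with
    /// `template_addr = 0x1008`, `template_memsz = 0x30`, and `template_align = 16` gets
    /// `offset = 0x30 + (wrapping_neg(0x1008 + 0x30) & 0xf) = 0x38`, so its data starts at
    /// `tp - 0x38`, which is congruent to `0x1008` modulo 16 and thus preserves the alignment
    /// of data within the template. Under Variant 1, offsets instead grow upward from just past
    /// the thread pointer.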
    pub fn insert(&mut self, mut tm: TlsModule) -> TlsModId {
        // Track size and alignment requirement changes.
        self.alloc_size_mods += tm.template_memsz;
        self.max_align = std::cmp::max(self.max_align, tm.template_align);
        self.max_align = self.max_align.next_power_of_two();

        let id = match get_tls_variant() {
            TlsVariant::Variant1 => {
                // Modules are placed after the thread pointer, growing upward. The first two
                // words (16 bytes) past the thread pointer are reserved (e.g. on aarch64), so
                // the first module starts after them.
                if self.tls_mods.is_empty() {
                    self.offset = 16;
                }
                // Make sure the current offset from the TP is aligned for this module.
                if !(self.offset as *const u8).is_aligned_to(tm.template_align) {
                    let ptr = self.offset as *const u8;
                    self.offset += ptr.align_offset(tm.template_align);
                }
                // The module's region starts at the aligned offset...
                tm.offset = Some(self.offset);

                // ...and the next module starts after it.
                self.offset += tm.template_memsz;
                // Save the module ID + 1 (leave one slot in the DTV for the generation count).
                TlsModId((self.tls_mods.len() + 1) as u64, tm.offset.unwrap())
            }
            TlsVariant::Variant2 => {
                // The first module is placed so that the region ends at the thread pointer.
                // Other regions are just placed at semi-arbitrary positions below that, so we
                // don't need to be as careful about them.
                if self.tls_mods.is_empty() {
                    // Set the offset so that the template ends at the thread pointer, while
                    // keeping the module's start congruent to the template address modulo its
                    // alignment.
                    self.offset = tm.template_memsz
                        + ((tm.template_addr + tm.template_memsz).overflowing_neg().0
                            & (tm.template_align - 1));
                } else {
                    // Set the offset so that the region starts aligned and has enough room.
                    self.offset += tm.template_memsz + tm.template_align - 1;
                    self.offset -= (self.offset + tm.template_addr) & (tm.template_align - 1);
                }
                tm.offset = Some(self.offset);
                // Save the module ID + 1 (leave one slot in the DTV for the generation count).
                TlsModId((self.tls_mods.len() + 1) as u64, self.offset)
            }
        };

        tm.id = Some(id);
        self.tls_mods.push(tm);
        id
    }

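    /// Build a thread's TLS region inside `alloc_base`, which is expected to point to an
    /// allocation satisfying `allocation_layout::<T>()`. This computes the thread pointer,
    /// copies in each module's template, fills out the DTV (slot 0 is the generation count,
    /// slot i is the address of module i's block), and writes the control block. Returns
    /// handles to the various parts of the region as a [`TlsRegion`].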
    pub(crate) fn allocate<T>(
        &self,
        _comp: &Compartment,
        alloc_base: NonNull<u8>,
        tcb: T,
    ) -> Result<TlsRegion, DynlinkError> {
        // Given an allocation region, let's find all the interesting pointers.
        let layout = self
            .allocation_layout::<T>()
            .map_err(|err| DynlinkErrorKind::LayoutError { err })?;

        // Compute the thread pointer from the base of the allocation.
        let thread_pointer = match get_tls_variant() {
            TlsVariant::Variant1 => {
                let mut base = usize::from(alloc_base.addr()) + size_of::<Tcb<T>>() - 16;
                base += base & (layout.align() - 1);
                NonNull::new(base as *mut u8).unwrap()
            }
            TlsVariant::Variant2 => {
                let mut base = usize::from(alloc_base.addr()) + layout.size() - size_of::<Tcb<T>>();
                // Align for the thread pointer and the 1st TLS region.
                base -= base & (layout.align() - 1);
                NonNull::new(base as *mut u8).unwrap()
            }
        };

        let module_start = match get_tls_variant() {
            TlsVariant::Variant1 => {
                // The first module lives past the 16 reserved bytes that follow the thread
                // pointer...
                let temp = unsafe { thread_pointer.as_ptr().add(16) };
                // ...aligned to the alignment of the first module.
                let padding_after_tcb = temp.align_offset(self.tls_mods[0].template_align);
                NonNull::new(unsafe { temp.add(padding_after_tcb) }).unwrap()
            }
            TlsVariant::Variant2 => thread_pointer,
        };

        // Calculate the start of the DTV.
        let dtv_ptr = match get_tls_variant() {
            TlsVariant::Variant1 => {
                // Variant 1 has the thread pointer pointing to the DTV pointer. The DTV array
                // itself is placed after the static TLS modules; `self.offset` at this point is
                // their total extent.
                let after_modules = unsafe { module_start.as_ptr().add(self.offset) };
                let align_padding = after_modules.align_offset(align_of::<usize>());
                let dtv_addr = unsafe { after_modules.add(align_padding).cast::<usize>() };
                NonNull::new(dtv_addr).unwrap()
            }
            TlsVariant::Variant2 => alloc_base.cast(),
        };

        let tls_region = TlsRegion {
            gen: self.gen,
            module_top: module_start,
            thread_pointer,
            dtv: dtv_ptr,
            num_dtv_entries: self.dtv_len(),
            alloc_base,
            layout,
        };

        // Each TLS module gets part of the region, with data copied from the template.
        for tm in &self.tls_mods {
            if !tm.is_static {
                error!("non-static TLS modules are not supported");
                continue;
            }
            tls_region.copy_in_module(tm);
            tls_region.set_dtv_entry(tm);
        }

        // Write the generation count into DTV slot 0.
        trace!("setting dtv[0] to gen_count {}", self.gen);
        unsafe { *tls_region.dtv.as_ptr() = self.gen as usize };

        // Finally, fill out the control block.
        unsafe { (tls_region.get_thread_control_block()).write(Tcb::new(&tls_region, tcb)) };

        Ok(tls_region)
    }

    fn dtv_len(&self) -> usize {
        // One extra entry for the generation count in slot 0.
        self.tls_mods.len() + 1
    }

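    /// Compute a conservative allocation layout for the whole TLS region (modules, padding,
    /// DTV, and TCB). As a rough, hypothetical example: with two modules totaling 0x60 bytes,
    /// `align = 0x40`, a 3-entry DTV (24 bytes on a 64-bit target), and a 64-byte `Tcb<T>`, the
    /// computed size would be 0x60 + 0x40 * 4 + 64 + 24 = 0x1b8 bytes.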
    pub(crate) fn allocation_layout<T>(&self) -> Result<Layout, std::alloc::LayoutError> {
        // Ensure that the alignment is enough for the control block.
        let align = std::cmp::max(self.max_align, align_of::<Tcb<T>>()).next_power_of_two();
        // The region needs space for each module, and we conservatively assume each one needs
        // padding up to the max alignment. Add two to the module count to cover alignment
        // padding for the DTV and the TCB as well.
        let region_size = self.alloc_size_mods + align * (self.tls_mods.len() + 2);
        let dtv_size = self.dtv_len() * size_of::<usize>();
        // We also need space for the control block and the DTV themselves.
        let size = region_size + size_of::<Tcb<T>>() + dtv_size;
        Layout::from_size_align(size, align)
    }
}

impl<T> Tcb<T> {
    pub(crate) fn new(tls_region: &TlsRegion, tcb_data: T) -> Self {
        let self_ptr = unsafe { tls_region.get_thread_control_block() };
        Self {
            self_ptr,
            dtv: tls_region.dtv.as_ptr(),
            dtv_len: tls_region.num_dtv_entries,
            runtime_data: tcb_data,
        }
    }

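    /// Resolve a `TlsIndex` to an address, in the style of `__tls_get_addr`: look up the
    /// module's block in the DTV and add the offset. Returns `None` if the module ID is out of
    /// range of this thread's DTV.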
    pub fn get_addr(&self, index: &TlsIndex) -> Option<*mut u8> {
        unsafe {
            let slice = core::slice::from_raw_parts(self.dtv, self.dtv_len);
            Some((slice.get(index.mod_id)? + index.offset) as *mut _)
        }
    }
}

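/// A module's TLS ID (also its DTV slot, since slot 0 holds the generation count) paired with
/// its assigned offset in the static TLS region.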
#[derive(Debug, Clone, Copy)]
pub struct TlsModId(u64, usize);

impl TlsModId {
    pub fn tls_id(&self) -> u64 {
        self.0
    }

    pub fn offset(&self) -> usize {
        self.1
    }
}

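/// Handles into a single thread's allocated TLS region: the base allocation and its layout, the
/// thread pointer, the DTV (and its length), and `module_top`, the boundary of the static
/// module area (its start under Variant 1, its top under Variant 2).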
#[repr(C)]
#[derive(Debug)]
pub struct TlsRegion {
    pub gen: u64,
    pub layout: Layout,
    pub alloc_base: NonNull<u8>,
    pub thread_pointer: NonNull<u8>,
    pub dtv: NonNull<usize>,
    pub num_dtv_entries: usize,
    pub module_top: NonNull<u8>,
}

impl TlsRegion {
    pub fn alloc_base(&self) -> *mut u8 {
        self.alloc_base.as_ptr()
    }

    pub fn alloc_layout(&self) -> Layout {
        self.layout
    }

    pub fn get_thread_pointer_value(&self) -> usize {
        self.thread_pointer.as_ptr() as usize
    }

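    /// Write this module's DTV entry: the absolute address of its TLS block within this
    /// thread's region (above the thread pointer for Variant 1, below `module_top` for
    /// Variant 2).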
    pub(crate) fn set_dtv_entry(&self, tm: &TlsModule) {
        let dtv_slice =
            unsafe { core::slice::from_raw_parts_mut(self.dtv.as_ptr(), self.num_dtv_entries) };
        let dtv_idx = tm.id.as_ref().unwrap().tls_id() as usize;
        let dtv_val = match get_tls_variant() {
            TlsVariant::Variant1 => unsafe { self.thread_pointer.as_ptr().add(tm.offset.unwrap()) },
            TlsVariant::Variant2 => unsafe { self.module_top.as_ptr().sub(tm.offset.unwrap()) },
        };
        trace!("setting dtv entry {} <= {:p}", dtv_idx, dtv_val);
        dtv_slice[dtv_idx] = dtv_val as usize;
    }

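    /// Copy a module's template data into its slot in this region. Only `template_filesz` bytes
    /// are copied; the remaining `template_memsz - template_filesz` bytes (the zero-initialized
    /// `.tbss` part) are assumed to already be zero in the fresh allocation.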
    pub(crate) fn copy_in_module(&self, tm: &TlsModule) -> usize {
        unsafe {
            let start = match get_tls_variant() {
                TlsVariant::Variant1 => self.thread_pointer.as_ptr().add(tm.offset.unwrap()),
                TlsVariant::Variant2 => self.module_top.as_ptr().sub(tm.offset.unwrap()),
            };
            let src = tm.template_addr as *const u8;
            trace!(
                "copy in static region {:p} => {:p} (filesz={}, memsz={})",
                src,
                start,
                tm.template_filesz,
                tm.template_memsz
            );
            start.copy_from_nonoverlapping(src, tm.template_filesz);
            start as usize
        }
    }
}

pub use crate::arch::get_current_thread_control_block;