// dynlink/context/load.rs

1use std::mem::size_of;
2
3use elf::{
4    abi::{DT_INIT, DT_INIT_ARRAY, DT_INIT_ARRAYSZ, DT_PREINIT_ARRAY, DT_PREINIT_ARRAYSZ, PT_TLS},
5    endian::NativeEndian,
6    file::Class,
7};
8use petgraph::stable_graph::NodeIndex;
9use secgate::RawSecGateInfo;
10use tracing::{debug, warn};
11use twizzler_rt_abi::core::CtorSet;
12
13use super::{Context, LoadedOrUnloaded};
14use crate::{
15    compartment::{Compartment, CompartmentId},
16    context::NewCompartmentFlags,
17    engines::{LoadCtx, LoadDirective, LoadFlags},
18    library::{AllowedGates, Library, LibraryId, SecgateInfo, UnloadedLibrary},
19    tls::TlsModule,
20    DynlinkError, DynlinkErrorKind, HeaderError,
21};
22
23#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Hash)]
24pub struct LoadIds {
25    pub comp: CompartmentId,
26    pub lib: LibraryId,
27}
28
29impl From<&Library> for LoadIds {
30    fn from(value: &Library) -> Self {
31        Self {
32            comp: value.comp_id,
33            lib: value.id(),
34        }
35    }
36}
37
38impl Context {
39    pub(crate) fn get_secgate_info(
40        &self,
41        libname: &str,
42        elf: &elf::ElfBytes<'_, NativeEndian>,
43        base_addr: usize,
44    ) -> Result<SecgateInfo, DynlinkError> {
45        let info = elf
46            .section_header_by_name(".twz_secgate_info")?
47            .map(|info| SecgateInfo {
48                info_addr: Some((info.sh_addr as usize) + base_addr),
49                num: (info.sh_size as usize) / core::mem::size_of::<RawSecGateInfo>(),
50            })
51            .unwrap_or_default();
52
53        debug!(
54            "{}: registered secure gate info: {} gates",
55            libname, info.num
56        );
57
58        Ok(info)
59    }
60    // Collect information about constructors.
61    pub(crate) fn get_ctor_info(
62        &self,
63        libname: &str,
64        elf: &elf::ElfBytes<'_, NativeEndian>,
65        base_addr: usize,
66    ) -> Result<CtorSet, DynlinkError> {
67        let dynamic = elf
68            .dynamic()?
69            .ok_or_else(|| DynlinkErrorKind::MissingSection {
70                name: "dynamic".to_string(),
71            })?;
72
73        // If this isn't present, just call it 0, since if there's an init_array, this entry must be
74        // present in valid ELF files.
75        let init_array_len = dynamic
76            .iter()
77            .find_map(|d| {
78                if d.d_tag == DT_INIT_ARRAYSZ {
79                    Some((d.d_val() as usize) / size_of::<usize>())
80                } else {
81                    None
82                }
83            })
84            .unwrap_or_default();
85
86        // Init array is a pointer to an array of function pointers.
87        let init_array = dynamic.iter().find_map(|d| {
88            if d.d_tag == DT_INIT_ARRAY && d.clone().d_ptr() != 0 {
89                Some(base_addr + d.d_ptr() as usize)
90            } else {
91                None
92            }
93        });
94
95        // Legacy _init call. Supported for, well, legacy.
96        let leg_init = dynamic.iter().find_map(|d| {
97            if d.d_tag == DT_INIT && d.clone().d_ptr() != 0 {
98                Some(base_addr + d.d_ptr() as usize)
99            } else {
100                None
101            }
102        });
103
104        if dynamic.iter().any(|d| d.d_tag == DT_PREINIT_ARRAY)
105            && dynamic
106                .iter()
107                .find(|d| d.d_tag == DT_PREINIT_ARRAYSZ)
108                .is_some_and(|d| d.d_val() > 0)
109        {
110            warn!("{}: PREINIT_ARRAY is unsupported", libname);
111        }
112
113        debug!(
114            "{}: ctor info: init_array: {:?} len={}, legacy: {:?}",
115            libname, init_array, init_array_len, leg_init
116        );
117        Ok(CtorSet {
118            legacy_init: leg_init.map(|x| unsafe { std::mem::transmute(x) }),
119            init_array: init_array.unwrap_or_default() as *mut _,
120            init_array_len,
121        })
122    }
123
124    // Load (map) a single library into memory via creating two objects, one for text, and one for
125    // data.
126    fn load(
127        &mut self,
128        comp_id: CompartmentId,
129        unlib: UnloadedLibrary,
130        idx: NodeIndex,
131        allowed_gates: AllowedGates,
132        load_ctx: &mut LoadCtx,
133    ) -> Result<Library, DynlinkError> {
134        let backing = self.engine.load_object(&unlib)?;
135        let elf = backing.get_elf()?;
136
137        // Step 0: sanity check the ELF header.
138
139        const EXPECTED_CLASS: Class = Class::ELF64;
140        const EXPECTED_VERSION: u32 = 1;
141        const EXPECTED_ABI: u8 = elf::abi::ELFOSABI_SYSV;
142        const EXPECTED_ABI_VERSION: u8 = 0;
143        const EXPECTED_TYPE: u16 = elf::abi::ET_DYN;
144
145        #[cfg(target_arch = "x86_64")]
146        const EXPECTED_MACHINE: u16 = elf::abi::EM_X86_64;
147
148        #[cfg(target_arch = "aarch64")]
149        const EXPECTED_MACHINE: u16 = elf::abi::EM_AARCH64;
150
151        if elf.ehdr.class != EXPECTED_CLASS {
152            return Err(DynlinkErrorKind::from(HeaderError::ClassMismatch {
153                expect: Class::ELF64,
154                got: elf.ehdr.class,
155            })
156            .into());
157        }
158
159        if elf.ehdr.version != EXPECTED_VERSION {
160            return Err(DynlinkErrorKind::from(HeaderError::VersionMismatch {
161                expect: EXPECTED_VERSION,
162                got: elf.ehdr.version,
163            })
164            .into());
165        }
166
167        if elf.ehdr.osabi != EXPECTED_ABI {
168            return Err(DynlinkErrorKind::from(HeaderError::OSABIMismatch {
169                expect: EXPECTED_ABI,
170                got: elf.ehdr.osabi,
171            })
172            .into());
173        }
174
175        if elf.ehdr.abiversion != EXPECTED_ABI_VERSION {
176            return Err(DynlinkErrorKind::from(HeaderError::ABIVersionMismatch {
177                expect: EXPECTED_ABI_VERSION,
178                got: elf.ehdr.abiversion,
179            })
180            .into());
181        }
182
183        if elf.ehdr.e_machine != EXPECTED_MACHINE {
184            return Err(DynlinkErrorKind::from(HeaderError::MachineMismatch {
185                expect: EXPECTED_MACHINE,
186                got: elf.ehdr.e_machine,
187            })
188            .into());
189        }
190
191        if elf.ehdr.e_type != EXPECTED_TYPE {
192            return Err(DynlinkErrorKind::from(HeaderError::ELFTypeMismatch {
193                expect: EXPECTED_TYPE,
194                got: elf.ehdr.e_type,
195            })
196            .into());
197        }
198
199        // Step 1: map the PT_LOAD directives to copy-from commands Twizzler can use for creating
200        // objects.
201        let directives: Vec<_> = elf
202            .segments()
203            .ok_or_else(|| DynlinkErrorKind::MissingSection {
204                name: "segment info".to_string(),
205            })?
206            .iter()
207            .filter(|p| p.p_type == elf::abi::PT_LOAD)
208            .map(|phdr| {
209                let ld = LoadDirective {
210                    load_flags: if phdr.p_flags & elf::abi::PF_W != 0 {
211                        LoadFlags::TARGETS_DATA
212                    } else {
213                        LoadFlags::empty()
214                    },
215                    vaddr: phdr.p_vaddr as usize,
216                    memsz: phdr.p_memsz as usize,
217                    offset: phdr.p_offset as usize,
218                    align: phdr.p_align as usize,
219                    filesz: phdr.p_filesz as usize,
220                };
221
222                debug!("{}: {:?}", unlib, ld);
223
224                ld
225            })
226            .collect();
227
228        // call the system impl to actually map things
229        let backings = self
230            .engine
231            .load_segments(&backing, &directives, comp_id, load_ctx)?;
232        if backings.is_empty() {
233            return Err(DynlinkErrorKind::NewBackingFail.into());
234        }
235        let base_addr = backings[0].load_addr();
236        debug!(
237            "{}: loaded to {:x} (data at {:x})",
238            unlib,
239            base_addr,
240            backings.get(1).map(|b| b.load_addr()).unwrap_or_default()
241        );
242
243        // Step 2: look for any TLS information, stored in program header PT_TLS.
244        let tls_phdr = elf
245            .segments()
246            .and_then(|phdrs| phdrs.iter().find(|phdr| phdr.p_type == PT_TLS));
247
248        let tls_id = tls_phdr
249            .map(|tls_phdr| {
250                let formatter = humansize::make_format(humansize::BINARY);
251                debug!(
252                    "{}: registering TLS data ({} total, {} copy)",
253                    unlib,
254                    formatter(tls_phdr.p_memsz),
255                    formatter(tls_phdr.p_filesz)
256                );
257                let tm = TlsModule::new_static(
258                    base_addr + tls_phdr.p_vaddr as usize,
259                    tls_phdr.p_filesz as usize,
260                    tls_phdr.p_memsz as usize,
261                    tls_phdr.p_align as usize,
262                );
263                let comp = &mut self.get_compartment_mut(comp_id)?;
264                Ok::<_, DynlinkError>(comp.insert(tm))
265            })
266            .transpose()?;
267
268        debug!("{}: got TLS ID {:?}", unlib, tls_id);
269
270        // Step 3: lookup constructor and secgate information for this library.
271        let ctor_info = self.get_ctor_info(&unlib.name, &elf, base_addr)?;
272        let secgate_info = self.get_secgate_info(&unlib.name, &elf, base_addr)?;
273
274        let comp = self.get_compartment(comp_id)?;
275        Ok(Library::new(
276            backing.full_name().to_owned(),
277            idx,
278            comp.id,
279            comp.name.clone(),
280            backing,
281            backings,
282            tls_id,
283            ctor_info,
284            secgate_info,
285            allowed_gates,
286        ))
287    }
288
289    fn find_cross_compartment_library(
290        &self,
291        unlib: &UnloadedLibrary,
292    ) -> Option<(NodeIndex, CompartmentId, &Compartment)> {
293        for (idx, comp) in self.compartments.iter().enumerate() {
294            if let Some(lib_id) = comp.1.library_names.get(&unlib.name) {
295                let lib = self.get_library(LibraryId(*lib_id));
296                if let Ok(lib) = lib {
297                    // Only allow cross-compartment refs for a library that has secure gates.
298                    if lib.secgate_info.info_addr.is_some() && lib.allows_gates() {
299                        return Some((*lib_id, CompartmentId(idx), comp.1));
300                    }
301                    return None;
302                }
303            }
304        }
305
306        None
307    }
308
309    fn has_secgate_info(&self, elf: &elf::ElfBytes<'_, NativeEndian>) -> bool {
310        elf.section_header_by_name(".twz_secgate_info")
311            .ok()
312            .is_some_and(|s| s.is_some())
313    }
314
315    fn select_compartment(
316        &mut self,
317        unlib: &UnloadedLibrary,
318        parent_comp_name: String,
319    ) -> Option<CompartmentId> {
320        let backing = self.engine.load_object(unlib).ok()?;
321        let elf = backing.get_elf().ok()?;
322        if self.has_secgate_info(&elf) {
323            let name = format!("{}::{}", parent_comp_name, unlib.name);
324            let id = self
325                .add_compartment(&name, NewCompartmentFlags::empty())
326                .ok()?;
327            tracing::debug!(
328                "creating new compartment {}({}) for library {}",
329                name,
330                id,
331                unlib.name
332            );
333            // TODO: Handle collisions
334            Some(id)
335        } else {
336            None
337        }
338    }
339
340    // Load a library and all its deps, using the supplied name resolution callback for deps.
341    pub(crate) fn load_library(
342        &mut self,
343        comp_id: CompartmentId,
344        root_unlib: UnloadedLibrary,
345        idx: NodeIndex,
346        allowed_gates: AllowedGates,
347        load_ctx: &mut LoadCtx,
348    ) -> Result<Vec<LoadIds>, DynlinkError> {
349        let root_comp_name = self.get_compartment(comp_id)?.name.clone();
350        debug!(
351            "loading library {} (idx = {:?}) in {}",
352            root_unlib, idx, root_comp_name
353        );
354        let mut ids = vec![];
355        // First load the main library.
356        let lib = self
357            .load(comp_id, root_unlib.clone(), idx, allowed_gates, load_ctx)
358            .map_err(|e| {
359                DynlinkError::new_collect(
360                    DynlinkErrorKind::LibraryLoadFail {
361                        library: root_unlib.clone(),
362                    },
363                    vec![e],
364                )
365            })?;
366        ids.push((&lib).into());
367
368        tracing::debug!("enumerating deps for {}", lib);
369        // Second, go through deps
370        let deps = self.enumerate_needed(&lib).map_err(|e| {
371            DynlinkError::new_collect(
372                DynlinkErrorKind::DepEnumerationFail {
373                    library: root_unlib.name.to_string(),
374                },
375                vec![e],
376            )
377        })?;
378        if !deps.is_empty() {
379            debug!("{}: loading {} dependencies", root_unlib, deps.len());
380        }
381        let deps = deps
382            .into_iter()
383            .map(|dep_unlib| {
384                // Dependency search + load alg:
385                // 1. Search library name in current compartment. If found, use that.
386                // 2. Fallback to searching globally for the name, by checking compartment by
387                //    compartment. If found, use that.
388                // 3. Okay, now we know we need to load the dep, so check if it can go in the
389                //    current compartment. If not, create a new compartment.
390                // 4. Finally, recurse to load it and its dependencies into either the current
391                //    compartment or the new one, if created.
392
393                let comp = self.get_compartment(comp_id)?;
394                let (existing_idx, load_comp) =
395                    if let Some(existing) = comp.library_names.get(&dep_unlib.name) {
396                        debug!(
397                            "{}: dep using existing library for {} (intra-compartment in {}): {:?}",
398                            root_unlib, dep_unlib.name, comp.name, existing
399                        );
400                        (Some(*existing), comp_id)
401                    } else if let Some((existing, other_comp_id, other_comp)) =
402                        self.find_cross_compartment_library(&dep_unlib)
403                    {
404                        debug!(
405                            "{}: dep using existing library for {} (cross-compartment to {}): {:?}",
406                            root_unlib, dep_unlib.name, other_comp.name, existing
407                        );
408                        (Some(existing), other_comp_id)
409                    } else {
410                        (
411                            None,
412                            self.select_compartment(&dep_unlib, root_comp_name.clone())
413                                .unwrap_or(comp_id),
414                        )
415                    };
416
417                // If we decided to use an existing library, then use that. Otherwise, load into the
418                // chosen compartment.
419                let idx = if let Some(existing_idx) = existing_idx {
420                    existing_idx
421                } else {
422                    let idx = self.add_library(dep_unlib.clone());
423
424                    let comp = self.get_compartment_mut(load_comp)?;
425                    comp.library_names.insert(dep_unlib.name.clone(), idx);
426                    let allowed_gates = if comp.id == comp_id {
427                        AllowedGates::Private
428                    } else {
429                        AllowedGates::Public
430                    };
431                    let mut recs = self
432                        .load_library(load_comp, dep_unlib.clone(), idx, allowed_gates, load_ctx)
433                        .map_err(|e| {
434                            tracing::error!("failed to load dependency for {}: {}", lib, e);
435                            DynlinkError::new_collect(
436                                DynlinkErrorKind::LibraryLoadFail {
437                                    library: dep_unlib.clone(),
438                                },
439                                vec![e],
440                            )
441                        })?;
442                    ids.append(&mut recs);
443                    idx
444                };
445                self.add_dep(lib.idx, idx);
446                Ok(idx)
447            })
448            .collect::<Vec<Result<_, DynlinkError>>>();
449
450        let _ = DynlinkError::collect(
451            DynlinkErrorKind::LibraryLoadFail {
452                library: root_unlib,
453            },
454            deps,
455        )?;
456
457        assert_eq!(idx, lib.idx);
458        self.library_deps[idx] = LoadedOrUnloaded::Loaded(lib);
459        Ok(ids)
460    }
461
462    /// Load a library into a given compartment.
463    pub fn load_library_in_compartment(
464        &mut self,
465        comp_id: CompartmentId,
466        unlib: UnloadedLibrary,
467        allowed_gates: AllowedGates,
468        load_ctx: &mut LoadCtx,
469    ) -> Result<Vec<LoadIds>, DynlinkError> {
470        let idx = self.add_library(unlib.clone());
471        // Step 1: insert into the compartment's library names.
472        let comp = self.get_compartment_mut(comp_id)?;
473
474        // At this level, it's an error to insert an already loaded library.
475        if comp.library_names.contains_key(&unlib.name) {
476            return Err(DynlinkErrorKind::NameAlreadyExists {
477                name: unlib.name.clone(),
478            }
479            .into());
480        }
481        comp.library_names.insert(unlib.name.clone(), idx);
482
483        // Step 2: load the library. This call recurses on dependencies.
484        self.load_library(comp_id, unlib.clone(), idx, allowed_gates, load_ctx)
485    }
486}