1use std::any::Any;
56use std::marker::PhantomData;
57use std::ops::Deref;
58use std::os::raw::{c_int, c_void};
59use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60use std::ptr;
61use std::slice;
62use std::sync::Arc;
63
64use crate::ffi;
65use crate::ffi::sqlite3_context;
66use crate::ffi::sqlite3_value;
67
68use crate::context::set_result;
69use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
70
71use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
72
73unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74 if let Error::SqliteFailure(ref err, ref s) = *err {
75 ffi::sqlite3_result_error_code(ctx, err.extended_code);
76 if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
77 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
78 }
79 } else {
80 ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION);
81 if let Ok(cstr) = str_to_cstring(&err.to_string()) {
82 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
83 }
84 }
85}
86
/// Destructor callback handed to SQLite: reclaims and drops a `Box<T>` that was
/// previously leaked with `Box::into_raw` and passed around as an untyped pointer.
///
/// # Safety
///
/// `p` must have been produced by `Box::into_raw` for a `Box<T>` and must not be
/// used again after this call.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let typed: *mut T = p.cast();
    // Rebuilding the box hands ownership back to Rust; it is dropped at scope end.
    let _reclaimed: Box<T> = Box::from_raw(typed);
}
90
/// Handle passed to user-defined SQL functions, giving access to the arguments of
/// the current call and to the underlying SQLite function context.
pub struct Context<'a> {
    // Raw SQLite context of the function invocation.
    ctx: *mut sqlite3_context,
    // Raw argument values of the invocation, borrowed for the duration of the call.
    args: &'a [*mut sqlite3_value],
}
97
98impl Context<'_> {
99 #[inline]
101 #[must_use]
102 pub fn len(&self) -> usize {
103 self.args.len()
104 }
105
106 #[inline]
108 #[must_use]
109 pub fn is_empty(&self) -> bool {
110 self.args.is_empty()
111 }
112
113 pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
123 let arg = self.args[idx];
124 let value = unsafe { ValueRef::from_value(arg) };
125 FromSql::column_result(value).map_err(|err| match err {
126 FromSqlError::InvalidType => {
127 Error::InvalidFunctionParameterType(idx, value.data_type())
128 }
129 FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
130 FromSqlError::Other(err) => {
131 Error::FromSqlConversionFailure(idx, value.data_type(), err)
132 }
133 FromSqlError::InvalidBlobSize { .. } => {
134 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
135 }
136 })
137 }
138
139 #[inline]
146 #[must_use]
147 pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
148 let arg = self.args[idx];
149 unsafe { ValueRef::from_value(arg) }
150 }
151
152 #[inline]
155 #[must_use]
156 pub fn get_arg(&self, idx: usize) -> SqlFnArg {
157 assert!(idx < self.len());
158 SqlFnArg { idx }
159 }
160
161 pub fn get_subtype(&self, idx: usize) -> std::os::raw::c_uint {
168 let arg = self.args[idx];
169 unsafe { ffi::sqlite3_value_subtype(arg) }
170 }
171
172 pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
185 where
186 T: Send + Sync + 'static,
187 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
188 F: FnOnce(ValueRef<'_>) -> Result<T, E>,
189 {
190 if let Some(v) = self.get_aux(arg)? {
191 Ok(v)
192 } else {
193 let vr = self.get_raw(arg as usize);
194 self.set_aux(
195 arg,
196 func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
197 )
198 }
199 }
200
201 pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
210 assert!(arg < self.len() as i32);
211 let orig: Arc<T> = Arc::new(value);
212 let inner: AuxInner = orig.clone();
213 let outer = Box::new(inner);
214 let raw: *mut AuxInner = Box::into_raw(outer);
215 unsafe {
216 ffi::sqlite3_set_auxdata(
217 self.ctx,
218 arg,
219 raw.cast(),
220 Some(free_boxed_value::<AuxInner>),
221 );
222 };
223 Ok(orig)
224 }
225
226 pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
236 assert!(arg < self.len() as i32);
237 let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
238 if p.is_null() {
239 Ok(None)
240 } else {
241 let v: AuxInner = AuxInner::clone(unsafe { &*p });
242 v.downcast::<T>()
243 .map(Some)
244 .map_err(|_| Error::GetAuxWrongType)
245 }
246 }
247
248 pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
255 let handle = ffi::sqlite3_context_db_handle(self.ctx);
256 Ok(ConnectionRef {
257 conn: Connection::from_handle(handle)?,
258 phantom: PhantomData,
259 })
260 }
261}
262
/// A [`Connection`] obtained from a function [`Context`]; it dereferences to
/// `Connection` and its lifetime is tied to the context it came from.
pub struct ConnectionRef<'ctx> {
    // Connection built from the database handle of the function context.
    conn: Connection,
    // Ties this value to the borrow of the originating `Context`.
    phantom: PhantomData<&'ctx Context<'ctx>>,
}
270
271impl Deref for ConnectionRef<'_> {
272 type Target = Connection;
273
274 #[inline]
275 fn deref(&self) -> &Connection {
276 &self.conn
277 }
278}
279
// Type-erased, shareable form in which auxiliary data is stored with SQLite.
type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
281
/// Optional subtype attached to a function result; `None` means no subtype is set.
pub type SubType = Option<std::os::raw::c_uint>;
284
/// Result of a user-defined SQL function: a value plus an optional subtype.
pub trait SqlFnOutput {
    /// Converts `self` into the SQLite output value and its optional subtype.
    fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>;
}
290
291impl<T: ToSql> SqlFnOutput for T {
292 #[inline]
293 fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
294 ToSql::to_sql(self).map(|o| (o, None))
295 }
296}
297
298impl<T: ToSql> SqlFnOutput for (T, SubType) {
299 fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
300 ToSql::to_sql(&self.0).map(|o| (o, self.1))
301 }
302}
303
/// Designates one argument of the current function call by position, so that it
/// can be returned as the function result (see [`Context::get_arg`]).
pub struct SqlFnArg {
    // Zero-based position of the argument.
    idx: usize,
}
308impl ToSql for SqlFnArg {
309 fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
310 Ok(ToSqlOutput::Arg(self.idx))
311 }
312}
313
314unsafe fn sql_result<T: SqlFnOutput>(
315 ctx: *mut sqlite3_context,
316 args: &[*mut sqlite3_value],
317 r: Result<T>,
318) {
319 let t = r.as_ref().map(SqlFnOutput::to_sql);
320
321 match t {
322 Ok(Ok((ref value, sub_type))) => {
323 set_result(ctx, args, value);
324 if let Some(sub_type) = sub_type {
325 ffi::sqlite3_result_subtype(ctx, sub_type);
326 }
327 }
328 Ok(Err(err)) => report_error(ctx, &err),
329 Err(err) => report_error(ctx, err),
330 };
331}
332
/// Behaviour of a user-defined aggregate function with accumulator type `A` and
/// result type `T`.
pub trait Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: SqlFnOutput,
{
    /// Creates the initial accumulator. The step callback invokes this lazily,
    /// the first time it runs for a given aggregation.
    fn init(&self, ctx: &mut Context<'_>) -> Result<A>;

    /// Folds the arguments of one call (available through `ctx`) into `acc`.
    fn step(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;

    /// Produces the final result. `acc` is `None` when no accumulator was ever
    /// created, i.e. `step` never initialised one.
    fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<T>;
}
364
/// Additional behaviour required from an [`Aggregate`] to be usable as a window
/// function.
#[cfg(feature = "window")]
#[cfg_attr(docsrs, doc(cfg(feature = "window")))]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: SqlFnOutput,
{
    /// Returns the current value of the aggregate without consuming it. `acc` is
    /// `None` when no accumulator has been created yet.
    fn value(&self, acc: Option<&mut A>) -> Result<T>;

    /// Counterpart of `step`: receives the call's arguments through `ctx` and
    /// updates `acc` (presumably to remove a row from the window — confirm against
    /// the SQLite window-function documentation).
    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
}
381
382bitflags::bitflags! {
383 #[derive(Clone, Copy, Debug)]
387 #[repr(C)]
388 pub struct FunctionFlags: c_int {
389 const SQLITE_UTF8 = ffi::SQLITE_UTF8;
391 const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
393 const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
395 const SQLITE_UTF16 = ffi::SQLITE_UTF16;
397 const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; const SQLITE_DIRECTONLY = 0x0000_0008_0000; const SQLITE_SUBTYPE = 0x0000_0010_0000; const SQLITE_INNOCUOUS = 0x0000_0020_0000; const SQLITE_RESULT_SUBTYPE = 0x0000_0100_0000; const SQLITE_SELFORDER1 = 0x0000_0200_0000; }
410}
411
412impl Default for FunctionFlags {
413 #[inline]
414 fn default() -> Self {
415 Self::SQLITE_UTF8
416 }
417}
418
419impl Connection {
420 #[inline]
458 pub fn create_scalar_function<F, T>(
459 &self,
460 fn_name: &str,
461 n_arg: c_int,
462 flags: FunctionFlags,
463 x_func: F,
464 ) -> Result<()>
465 where
466 F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
467 T: SqlFnOutput,
468 {
469 self.db
470 .borrow_mut()
471 .create_scalar_function(fn_name, n_arg, flags, x_func)
472 }
473
474 #[inline]
481 pub fn create_aggregate_function<A, D, T>(
482 &self,
483 fn_name: &str,
484 n_arg: c_int,
485 flags: FunctionFlags,
486 aggr: D,
487 ) -> Result<()>
488 where
489 A: RefUnwindSafe + UnwindSafe,
490 D: Aggregate<A, T> + 'static,
491 T: SqlFnOutput,
492 {
493 self.db
494 .borrow_mut()
495 .create_aggregate_function(fn_name, n_arg, flags, aggr)
496 }
497
498 #[cfg(feature = "window")]
504 #[cfg_attr(docsrs, doc(cfg(feature = "window")))]
505 #[inline]
506 pub fn create_window_function<A, W, T>(
507 &self,
508 fn_name: &str,
509 n_arg: c_int,
510 flags: FunctionFlags,
511 aggr: W,
512 ) -> Result<()>
513 where
514 A: RefUnwindSafe + UnwindSafe,
515 W: WindowAggregate<A, T> + 'static,
516 T: SqlFnOutput,
517 {
518 self.db
519 .borrow_mut()
520 .create_window_function(fn_name, n_arg, flags, aggr)
521 }
522
523 #[inline]
534 pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
535 self.db.borrow_mut().remove_function(fn_name, n_arg)
536 }
537}
538
539impl InnerConnection {
540 fn create_scalar_function<F, T>(
562 &mut self,
563 fn_name: &str,
564 n_arg: c_int,
565 flags: FunctionFlags,
566 x_func: F,
567 ) -> Result<()>
568 where
569 F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
570 T: SqlFnOutput,
571 {
572 unsafe extern "C" fn call_boxed_closure<F, T>(
573 ctx: *mut sqlite3_context,
574 argc: c_int,
575 argv: *mut *mut sqlite3_value,
576 ) where
577 F: Fn(&Context<'_>) -> Result<T>,
578 T: SqlFnOutput,
579 {
580 let args = slice::from_raw_parts(argv, argc as usize);
581 let r = catch_unwind(|| {
582 let boxed_f: *const F = ffi::sqlite3_user_data(ctx).cast::<F>();
583 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
584 let ctx = Context { ctx, args };
585 (*boxed_f)(&ctx)
586 });
587 let t = match r {
588 Err(_) => {
589 report_error(ctx, &Error::UnwindingPanic);
590 return;
591 }
592 Ok(r) => r,
593 };
594 sql_result(ctx, args, t);
595 }
596
597 let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
598 let c_name = str_to_cstring(fn_name)?;
599 let r = unsafe {
600 ffi::sqlite3_create_function_v2(
601 self.db(),
602 c_name.as_ptr(),
603 n_arg,
604 flags.bits(),
605 boxed_f.cast::<c_void>(),
606 Some(call_boxed_closure::<F, T>),
607 None,
608 None,
609 Some(free_boxed_value::<F>),
610 )
611 };
612 self.decode_result(r)
613 }
614
615 fn create_aggregate_function<A, D, T>(
616 &mut self,
617 fn_name: &str,
618 n_arg: c_int,
619 flags: FunctionFlags,
620 aggr: D,
621 ) -> Result<()>
622 where
623 A: RefUnwindSafe + UnwindSafe,
624 D: Aggregate<A, T> + 'static,
625 T: SqlFnOutput,
626 {
627 let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
628 let c_name = str_to_cstring(fn_name)?;
629 let r = unsafe {
630 ffi::sqlite3_create_function_v2(
631 self.db(),
632 c_name.as_ptr(),
633 n_arg,
634 flags.bits(),
635 boxed_aggr.cast::<c_void>(),
636 None,
637 Some(call_boxed_step::<A, D, T>),
638 Some(call_boxed_final::<A, D, T>),
639 Some(free_boxed_value::<D>),
640 )
641 };
642 self.decode_result(r)
643 }
644
645 #[cfg(feature = "window")]
646 fn create_window_function<A, W, T>(
647 &mut self,
648 fn_name: &str,
649 n_arg: c_int,
650 flags: FunctionFlags,
651 aggr: W,
652 ) -> Result<()>
653 where
654 A: RefUnwindSafe + UnwindSafe,
655 W: WindowAggregate<A, T> + 'static,
656 T: SqlFnOutput,
657 {
658 let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
659 let c_name = str_to_cstring(fn_name)?;
660 let r = unsafe {
661 ffi::sqlite3_create_window_function(
662 self.db(),
663 c_name.as_ptr(),
664 n_arg,
665 flags.bits(),
666 boxed_aggr.cast::<c_void>(),
667 Some(call_boxed_step::<A, W, T>),
668 Some(call_boxed_final::<A, W, T>),
669 Some(call_boxed_value::<A, W, T>),
670 Some(call_boxed_inverse::<A, W, T>),
671 Some(free_boxed_value::<W>),
672 )
673 };
674 self.decode_result(r)
675 }
676
677 fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
678 let c_name = str_to_cstring(fn_name)?;
679 let r = unsafe {
680 ffi::sqlite3_create_function_v2(
681 self.db(),
682 c_name.as_ptr(),
683 n_arg,
684 ffi::SQLITE_UTF8,
685 ptr::null_mut(),
686 None,
687 None,
688 None,
689 None,
690 )
691 };
692 self.decode_result(r)
693 }
694}
695
696unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
697 let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
698 if pac.is_null() {
699 return None;
700 }
701 Some(pac)
702}
703
704unsafe extern "C" fn call_boxed_step<A, D, T>(
705 ctx: *mut sqlite3_context,
706 argc: c_int,
707 argv: *mut *mut sqlite3_value,
708) where
709 A: RefUnwindSafe + UnwindSafe,
710 D: Aggregate<A, T>,
711 T: SqlFnOutput,
712{
713 let Some(pac) = aggregate_context(ctx, size_of::<*mut A>()) else {
714 ffi::sqlite3_result_error_nomem(ctx);
715 return;
716 };
717
718 let r = catch_unwind(|| {
719 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
720 assert!(
721 !boxed_aggr.is_null(),
722 "Internal error - null aggregate pointer"
723 );
724 let mut ctx = Context {
725 ctx,
726 args: slice::from_raw_parts(argv, argc as usize),
727 };
728
729 #[expect(clippy::unnecessary_cast)]
730 if (*pac as *mut A).is_null() {
731 *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
732 }
733
734 (*boxed_aggr).step(&mut ctx, &mut **pac)
735 });
736 let r = match r {
737 Err(_) => {
738 report_error(ctx, &Error::UnwindingPanic);
739 return;
740 }
741 Ok(r) => r,
742 };
743 match r {
744 Ok(_) => {}
745 Err(err) => report_error(ctx, &err),
746 };
747}
748
/// Inverse callback registered with SQLite for window functions.
///
/// Obtains the accumulator slot from the aggregate context and runs `inverse` on
/// the accumulator it points to. Panics are contained and reported as
/// [`Error::UnwindingPanic`]; a missing aggregate context is reported as an
/// out-of-memory error.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: SqlFnOutput,
{
    let Some(pac) = aggregate_context(ctx, size_of::<*mut A>()) else {
        ffi::sqlite3_result_error_nomem(ctx);
        return;
    };

    let outcome = catch_unwind(|| {
        let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(
            !boxed_aggr.is_null(),
            "Internal error - null aggregate pointer"
        );
        let mut call_ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        (*boxed_aggr).inverse(&mut call_ctx, &mut **pac)
    });

    match outcome {
        Ok(Ok(())) => {}
        Ok(Err(err)) => report_error(ctx, &err),
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
    }
}
788
789unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
790where
791 A: RefUnwindSafe + UnwindSafe,
792 D: Aggregate<A, T>,
793 T: SqlFnOutput,
794{
795 let a: Option<A> = match aggregate_context(ctx, 0) {
798 Some(pac) =>
799 {
800 #[expect(clippy::unnecessary_cast)]
801 if (*pac as *mut A).is_null() {
802 None
803 } else {
804 let a = Box::from_raw(*pac);
805 Some(*a)
806 }
807 }
808 None => None,
809 };
810
811 let r = catch_unwind(|| {
812 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
813 assert!(
814 !boxed_aggr.is_null(),
815 "Internal error - null aggregate pointer"
816 );
817 let mut ctx = Context { ctx, args: &mut [] };
818 (*boxed_aggr).finalize(&mut ctx, a)
819 });
820 let t = match r {
821 Err(_) => {
822 report_error(ctx, &Error::UnwindingPanic);
823 return;
824 }
825 Ok(r) => r,
826 };
827 sql_result(ctx, &[], t);
828}
829
/// Value callback registered with SQLite for window functions.
///
/// Looks up the accumulator without taking ownership, passes it (or `None` when
/// there is none) to `value` and publishes the outcome as the function result.
/// Panics are contained and reported as [`Error::UnwindingPanic`].
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: SqlFnOutput,
{
    // Requesting zero bytes only looks up an existing context; an empty slot is
    // treated the same as a missing context.
    let pac = aggregate_context(ctx, 0).filter(|&pac| {
        #[expect(clippy::unnecessary_cast)]
        !(*pac as *mut A).is_null()
    });

    let outcome = catch_unwind(|| {
        let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(
            !boxed_aggr.is_null(),
            "Internal error - null aggregate pointer"
        );
        (*boxed_aggr).value(pac.map(|pac| &mut **pac))
    });

    match outcome {
        Ok(t) => sql_result(ctx, &[], t),
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
    }
}
861
#[cfg(test)]
mod test {
    use regex::Regex;
    use std::os::raw::c_double;

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags, SqlFnArg, SubType};
    use crate::{Connection, Error, Result};

    /// Scalar function used by the tests: halves its single floating-point argument.
    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert!(!ctx.is_empty());
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        // The connection running the function must be reachable from the context.
        assert!(unsafe {
            ctx.get_connection()
                .as_ref()
                .map(::std::ops::Deref::deref)
                .is_ok()
        });
        let input: c_double = ctx.get(0)?;
        Ok(input / 2f64)
    }

    /// A registered scalar function is callable from SQL.
    #[test]
    fn test_function_half() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_scalar_function("half", 1, flags, half)?;

        let result: f64 = db.one_column("SELECT half(6)")?;
        assert!((3f64 - result).abs() < f64::EPSILON);
        Ok(())
    }

    /// After removal, calling the function from SQL fails.
    #[test]
    fn test_remove_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_scalar_function("half", 1, flags, half)?;

        let result: f64 = db.one_column("SELECT half(6)")?;
        assert!((3f64 - result).abs() < f64::EPSILON);

        db.remove_function("half", 1)?;
        let result: Result<f64> = db.one_column("SELECT half(6)");
        result.unwrap_err();
        Ok(())
    }

    /// `regexp(pattern, text)` whose compiled pattern is cached as auxiliary data
    /// of the first argument.
    fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let text = ctx
            .get_raw(1)
            .as_str()
            .map_err(|e| Error::UserFunctionError(e.into()))?;

        Ok(regexp.is_match(text))
    }

    /// The auxiliary-data based `regexp` works for a literal and over table rows.
    #[test]
    fn test_function_regexp_with_auxiliary() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            "BEGIN;
             CREATE TABLE foo (x string);
             INSERT INTO foo VALUES ('lisa');
             INSERT INTO foo VALUES ('lXsi');
             INSERT INTO foo VALUES ('lisX');
             END;",
        )?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_scalar_function("regexp", 2, flags, regexp_with_auxiliary)?;

        let result: bool = db.one_column("SELECT regexp('l.s[aeiouy]', 'lisa')")?;

        assert!(result);

        let result: i64 =
            db.one_column("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1")?;

        assert_eq!(2, result);
        Ok(())
    }

    /// A function registered with `n_arg = -1` accepts any number of arguments.
    #[test]
    fn test_varargs_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_scalar_function("my_concat", -1, flags, |ctx| {
            let mut joined = String::new();
            for position in 0..ctx.len() {
                let piece = ctx.get::<String>(position)?;
                joined.push_str(&piece);
            }
            Ok(joined)
        })?;

        let cases = [
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ];
        for (expected, query) in cases {
            let result: String = db.one_column(query)?;
            assert_eq!(expected, result);
        }
        Ok(())
    }

    /// Auxiliary data stored as one type cannot be read back as another type.
    #[test]
    fn test_get_aux_type_checking() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
            if ctx.get::<bool>(1)? {
                // Second row: the `i64` stored by the first row must be visible,
                // and asking for a `String` must fail.
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
            } else {
                ctx.set_aux::<i64>(0, 100)?;
            }
            Ok(true)
        })?;

        let res: bool =
            db.one_column("SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)")?;
        assert!(res);
        Ok(())
    }

    // Aggregate implementations exercised by the tests below.
    struct Sum;
    struct Count;

    /// Sums its integer argument; yields `None` when no row was stepped.
    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            let addend = ctx.get::<i64>(0)?;
            *sum += addend;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
            Ok(sum)
        }
    }

    /// Counts the stepped rows; yields `0` when no row was stepped.
    impl Aggregate<i64, i64> for Count {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += 1;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
            Ok(sum.unwrap_or(0))
        }
    }

    /// `my_sum` over zero rows, two rows, and two columns at once.
    #[test]
    fn test_sum() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_aggregate_function("my_sum", 1, flags, Sum)?;

        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: Option<i64> = db.one_column(no_result)?;
        assert!(result.is_none());

        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.one_column(single_sum)?;
        assert_eq!(4, result);

        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
                        2, 1)";
        let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
        assert_eq!((4, 2), result);
        Ok(())
    }

    /// `my_count` over zero rows and over two rows.
    #[test]
    fn test_count() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_aggregate_function("my_count", -1, flags, Count)?;

        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: i64 = db.one_column(no_result)?;
        assert_eq!(result, 0);

        let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.one_column(single_sum)?;
        assert_eq!(2, result);
        Ok(())
    }

    /// Window-function extension of `Sum`: `inverse` subtracts what `step` added.
    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            let removed = ctx.get::<i64>(0)?;
            *sum -= removed;
            Ok(())
        }

        fn value(&self, sum: Option<&mut i64>) -> Result<Option<i64>> {
            Ok(sum.copied())
        }
    }

    /// `sumint` used as a sliding-window sum over neighbouring rows.
    #[test]
    #[cfg(feature = "window")]
    fn test_window() -> Result<()> {
        use fallible_iterator::FallibleIterator;

        let db = Connection::open_in_memory()?;
        let flags = FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC;
        db.create_window_function("sumint", 1, flags, Sum)?;
        db.execute_batch(
            "CREATE TABLE t3(x, y);
             INSERT INTO t3 VALUES('a', 4),
                                 ('b', 5),
                                 ('c', 3),
                                 ('d', 8),
                                 ('e', 1);",
        )?;

        let mut stmt = db.prepare(
            "SELECT x, sumint(y) OVER (
                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
                 ) AS sum_y
                 FROM t3 ORDER BY x;",
        )?;

        let results: Vec<(String, i64)> = stmt
            .query([])?
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()?;
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
        Ok(())
    }

    /// A subtype set on a function result is visible to the consuming function.
    #[test]
    fn test_sub_type() -> Result<()> {
        /// Returns the subtype of its single argument.
        fn test_getsubtype(ctx: &Context<'_>) -> Result<i32> {
            let sub_type = ctx.get_subtype(0);
            Ok(sub_type as i32)
        }
        /// Returns its first argument tagged with the subtype given as second argument.
        fn test_setsubtype(ctx: &Context<'_>) -> Result<(SqlFnArg, SubType)> {
            use std::os::raw::c_uint;
            let value = ctx.get_arg(0);
            let sub_type = ctx.get::<c_uint>(1)?;
            Ok((value, Some(sub_type)))
        }

        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "test_getsubtype",
            1,
            FunctionFlags::SQLITE_UTF8,
            test_getsubtype,
        )?;
        db.create_scalar_function(
            "test_setsubtype",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_RESULT_SUBTYPE,
            test_setsubtype,
        )?;

        let result: i32 = db.one_column("SELECT test_getsubtype('hello');")?;
        assert_eq!(0, result);

        let result: i32 = db.one_column("SELECT test_getsubtype(test_setsubtype('hello',123));")?;
        assert_eq!(123, result);

        Ok(())
    }
}