texlang_stdlib/
conditional.rs

//! Control flow primitives (if, else, switch)
//!
//! This module contains implementations of the generic conditional commands
//!   (`\iftrue`, `\iffalse`, `\ifnum`, `\ifodd`, `\ifcase`, `\or`, `\else` and `\fi`),
//!   as well as a mechanism for adding new conditional commands outside
//!   of this module.
//! See the [Condition] trait for information on adding new commands.
7
8use std::cell::RefCell;
9use texlang::parse::Ordering;
10use texlang::prelude as txl;
11use texlang::traits::*;
12use texlang::*;
13
// Documentation strings attached to the built-in commands returned by the `get_*` functions
// of this module.
const ELSE_DOC: &str = "Start the else branch of a conditional or switch statement";
const IFCASE_DOC: &str = "Begin a switch statement";
const IFNUM_DOC: &str = "Compare two integers";
const IFODD_DOC: &str = "Check if an integer is odd";
const IFTRUE_DOC: &str = "Evaluate the true branch";
const IFFALSE_DOC: &str = "Evaluate the false branch";
const FI_DOC: &str = "End a conditional or switch statement";
const OR_DOC: &str = "Begin the next branch of a switch statement";
22
23/// A component for keeping track of conditional branches as they are expanded.
24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25#[derive(Default)]
26pub struct Component {
27    // Branches is a stack where each element corresponds to a conditional that is currently
28    // expanding. A nested conditional is further up the stack than the conditional it is
29    // nested in.
30    //
31    // This stack is used to
32    // verify that \else and \fi tokens are valid; i.e., if a \else is encountered, the current
33    // conditional must be true otherwise the \else is invalid. For correct TeX code, the stack
34    // is never actually used.
35    //
36    // Because the conditional commands are expansion commands, they cannot get a mutable reference
37    // to the state. We thus wrap the branches in a ref cell to support mutating them through
38    // an immutable reference.
39    #[cfg_attr(
40        feature = "serde",
41        serde(
42            serialize_with = "serialize_branches",
43            deserialize_with = "deserialize_branches"
44        )
45    )]
46    branches: RefCell<Vec<Branch>>,
47
48    // We cache the tag values inside the component for performance reasons.
49    #[cfg_attr(feature = "serde", serde(skip))]
50    tags: Tags,
51}
52
/// Serializes the branch stack as a sequence, holding a shared borrow of the `RefCell`
/// only for the duration of the call.
#[cfg(feature = "serde")]
fn serialize_branches<S>(input: &RefCell<Vec<Branch>>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let branches = input.borrow();
    serde::Serialize::serialize(branches.as_slice(), serializer)
}
62
/// Deserializes the branch stack from a sequence and wraps it in a fresh `RefCell`.
#[cfg(feature = "serde")]
fn deserialize_branches<'de, D>(deserializer: D) -> Result<RefCell<Vec<Branch>>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    <Vec<Branch> as serde::Deserialize<'de>>::deserialize(deserializer).map(RefCell::new)
}
72
73struct Tags {
74    if_tag: command::Tag,
75    else_tag: command::Tag,
76    or_tag: command::Tag,
77    fi_tag: command::Tag,
78}
79
80impl Default for Tags {
81    fn default() -> Self {
82        Self {
83            if_tag: IF_TAG.get(),
84            else_tag: ELSE_TAG.get(),
85            or_tag: OR_TAG.get(),
86            fi_tag: FI_TAG.get(),
87        }
88    }
89}
90
/// Which branch of a conditional is currently being expanded.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum BranchKind {
    /// The true branch of an `if` conditional.
    True,
    /// The false branch of an `if` conditional, or the default (`else`) branch of a switch
    /// statement.
    Else,
    /// A regular case branch of a switch statement.
    Switch,
}
101
102#[derive(Debug)]
103#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
104struct Branch {
105    _token: token::Token,
106    kind: BranchKind,
107}
108
109fn push_branch<S: HasComponent<Component>>(input: &mut vm::ExpansionInput<S>, branch: Branch) {
110    input.state().component().branches.borrow_mut().push(branch)
111}
112
113fn pop_branch<S: HasComponent<Component>>(input: &mut vm::ExpansionInput<S>) -> Option<Branch> {
114    input.state().component().branches.borrow_mut().pop()
115}
116
117static IF_TAG: command::StaticTag = command::StaticTag::new();
118static ELSE_TAG: command::StaticTag = command::StaticTag::new();
119static OR_TAG: command::StaticTag = command::StaticTag::new();
120static FI_TAG: command::StaticTag = command::StaticTag::new();
121
122// The `true_case` function is executed whenever a conditional evaluates to true.
123fn true_case<S: HasComponent<Component>>(
124    token: token::Token,
125    input: &mut vm::ExpansionInput<S>,
126) -> txl::Result<()> {
127    push_branch(
128        input,
129        Branch {
130            _token: token,
131            kind: BranchKind::True,
132        },
133    );
134    Ok(())
135}
136
137// The `false_case` function is executed whenever a conditional evaluates to false.
138//
139// The function scans forward in the input stream, discarding all tokens, until it encounters
140// either a \else or \fi command.
141fn false_case<S: HasComponent<Component>>(
142    original_token: token::Token,
143    input: &mut vm::ExpansionInput<S>,
144) -> txl::Result<()> {
145    let mut depth = 0;
146    loop {
147        let token = input
148            .unexpanded()
149            .next_or_err(FalseBranchEndOfInputError {})?;
150        if let token::Value::CommandRef(command_ref) = &token.value() {
151            let tag = input.commands_map().get_tag(command_ref);
152            if tag == Some(input.state().component().tags.else_tag) && depth == 0 {
153                push_branch(
154                    input,
155                    Branch {
156                        _token: original_token,
157                        kind: BranchKind::Else,
158                    },
159                );
160                return Ok(());
161            }
162            if tag == Some(input.state().component().tags.if_tag) {
163                depth += 1;
164            }
165            if tag == Some(input.state().component().tags.fi_tag) {
166                depth -= 1;
167                if depth < 0 {
168                    return Ok(());
169                }
170            }
171        }
172    }
173}
174
175#[derive(Debug)]
176struct FalseBranchEndOfInputError;
177
178impl error::EndOfInputError for FalseBranchEndOfInputError {
179    fn doing(&self) -> String {
180        r"skipping the true branch of an conditional command".into()
181    }
182    fn notes(&self) -> Vec<error::display::Note> {
183        vec![
184            "each `if` command must be terminated by a `fi` command, with an optional `else` in between".into(),
185            "this `if` command evaluated to false, and the input ended while skipping the true branch".into(),
186        ]
187    }
188}
189
190/// Logical condition used to build `if` conditional commands.
191///
192/// This trait can be used to build new conditional commands outside of this module.
193/// To do this, just create a new type (generally a ZST)
194///   and implement the [`evaluate`](Condition::evaluate) method of the trait for the type.
195/// A conditional command that uses the condition can then be obtained
196///   using the [`build_if_command`](Condition::build_if_command) method.
197pub trait Condition<S: HasComponent<Component>> {
198    /// Optional documentation for the command built using this condition.
199    const DOC: Option<&'static str> = None;
200
201    /// Evaluate the condition.
202    ///
203    /// Returns `true` if the condition is true, `false` if it is false,
204    ///   and an error otherwise.
205    fn evaluate(input: &mut vm::ExpansionInput<S>) -> txl::Result<bool>;
206
207    /// Build an `if` conditional command that uses this condition.
208    fn build_if_command() -> command::BuiltIn<S> {
209        let primitive_fn =
210            |token: token::Token, input: &mut vm::ExpansionInput<S>| -> txl::Result<()> {
211                match Self::evaluate(input)? {
212                    true => true_case(token, input),
213                    false => false_case(token, input),
214                }
215            };
216        let mut cmd = command::BuiltIn::new_expansion(primitive_fn).with_tag(IF_TAG.get());
217        if let Some(doc) = Self::DOC {
218            cmd = cmd.with_doc(doc)
219        };
220        cmd
221    }
222}
223
224struct IfTrue;
225
226impl<S: HasComponent<Component>> Condition<S> for IfTrue {
227    const DOC: Option<&'static str> = Some(IFTRUE_DOC);
228
229    fn evaluate(_: &mut vm::ExpansionInput<S>) -> txl::Result<bool> {
230        Ok(true)
231    }
232}
233
234/// Get the `\iftrue` primitive.
235pub fn get_iftrue<S: HasComponent<Component>>() -> command::BuiltIn<S> {
236    IfTrue::build_if_command()
237}
238
239struct IfFalse;
240
241impl<S: HasComponent<Component>> Condition<S> for IfFalse {
242    const DOC: Option<&'static str> = Some(IFFALSE_DOC);
243
244    fn evaluate(_: &mut vm::ExpansionInput<S>) -> txl::Result<bool> {
245        Ok(false)
246    }
247}
248
249/// Get the `\iffalse` primitive.
250pub fn get_iffalse<S: HasComponent<Component>>() -> command::BuiltIn<S> {
251    IfFalse::build_if_command()
252}
253
254struct IfNum;
255
256impl<S: HasComponent<Component>> Condition<S> for IfNum {
257    const DOC: Option<&'static str> = Some(IFNUM_DOC);
258
259    fn evaluate(input: &mut vm::ExpansionInput<S>) -> txl::Result<bool> {
260        let (a, o, b) = <(i32, Ordering, i32)>::parse(input)?;
261        Ok(a.cmp(&b) == o.0)
262    }
263}
264
265/// Get the `\ifnum` primitive.
266pub fn get_ifnum<S: HasComponent<Component>>() -> command::BuiltIn<S> {
267    IfNum::build_if_command()
268}
269
270struct IfOdd;
271
272impl<S: HasComponent<Component>> Condition<S> for IfOdd {
273    const DOC: Option<&'static str> = Some(IFODD_DOC);
274
275    fn evaluate(input: &mut vm::ExpansionInput<S>) -> txl::Result<bool> {
276        let n = i32::parse(input)?;
277        Ok((n % 2) == 1)
278    }
279}
280
281/// Get the `\ifodd` primitive.
282pub fn get_ifodd<S: HasComponent<Component>>() -> command::BuiltIn<S> {
283    IfOdd::build_if_command()
284}
285
286fn if_case_primitive_fn<S: HasComponent<Component>>(
287    ifcase_token: token::Token,
288    input: &mut vm::ExpansionInput<S>,
289) -> txl::Result<()> {
290    // TODO: should we reading the number from the unexpanded stream? Probably!
291    let total_cases_to_skip = i32::parse(input)?;
292    if total_cases_to_skip == 0 {
293        push_branch(
294            input,
295            Branch {
296                _token: ifcase_token,
297                kind: BranchKind::Switch,
298            },
299        );
300        return Ok(());
301    }
302    let mut cases_left_to_skip = total_cases_to_skip;
303    let mut depth = 0;
304    loop {
305        let token = input.unexpanded().next_or_err(IfCaseEndOfInputError {
306            total_cases_to_skip,
307            cases_left_to_skip,
308        })?;
309        if let token::Value::CommandRef(command_ref) = &token.value() {
310            let tag = input.commands_map().get_tag(command_ref);
311            if tag == Some(input.state().component().tags.or_tag) && depth == 0 {
312                cases_left_to_skip -= 1;
313                if cases_left_to_skip == 0 {
314                    push_branch(
315                        input,
316                        Branch {
317                            _token: ifcase_token,
318                            kind: BranchKind::Switch,
319                        },
320                    );
321                    return Ok(());
322                }
323            }
324            if tag == Some(input.state().component().tags.else_tag) && depth == 0 {
325                push_branch(
326                    input,
327                    Branch {
328                        _token: ifcase_token,
329                        kind: BranchKind::Else,
330                    },
331                );
332                return Ok(());
333            }
334            if tag == Some(input.state().component().tags.if_tag) {
335                depth += 1;
336            }
337            if tag == Some(input.state().component().tags.fi_tag) {
338                depth -= 1;
339                if depth < 0 {
340                    return Ok(());
341                }
342            }
343        }
344    }
345}
346
347#[derive(Debug)]
348struct IfCaseEndOfInputError {
349    total_cases_to_skip: i32,
350    cases_left_to_skip: i32,
351}
352
353impl error::EndOfInputError for IfCaseEndOfInputError {
354    fn doing(&self) -> String {
355        "skipping cases in an `ifcase` command".into()
356    }
357
358    fn notes(&self) -> Vec<error::display::Note> {
359        vec![
360            "each `ifcase` command must be matched by a `or`, `else` or `fi` command".into(),
361            format![
362                "this `ifcase` case evaluated to {}",
363                self.total_cases_to_skip
364            ]
365            .into(),
366            format![
367                "the input ended while skipping case {}",
368                self.total_cases_to_skip + 1 - self.cases_left_to_skip
369            ]
370            .into(),
371        ]
372    }
373}
374
375/// Get the `\ifcase` primitive.
376pub fn get_ifcase<S: HasComponent<Component>>() -> command::BuiltIn<S> {
377    command::BuiltIn::new_expansion(if_case_primitive_fn)
378        .with_tag(IF_TAG.get())
379        .with_doc(IFCASE_DOC)
380}
381
382fn or_primitive_fn<S: HasComponent<Component>>(
383    ifcase_token: token::Token,
384    input: &mut vm::ExpansionInput<S>,
385) -> txl::Result<()> {
386    let branch = pop_branch(input);
387    // For an or command to be valid, we must be in a switch statement
388    let is_valid = match branch {
389        None => false,
390        Some(branch) => matches!(branch.kind, BranchKind::Switch),
391    };
392    if !is_valid {
393        input.error(error::SimpleTokenError::new(
394            ifcase_token,
395            "unexpected `or` command",
396        ))?;
397        return Ok(());
398    }
399
400    let mut depth = 0;
401    loop {
402        let token = input.unexpanded().next_or_err(OrEndOfInputError {})?;
403        if let token::Value::CommandRef(command_ref) = &token.value() {
404            let tag = input.commands_map().get_tag(command_ref);
405            if tag == Some(input.state().component().tags.if_tag) {
406                depth += 1;
407            }
408            if tag == Some(input.state().component().tags.fi_tag) {
409                depth -= 1;
410                if depth < 0 {
411                    return Ok(());
412                }
413            }
414        }
415    }
416}
417
418#[derive(Debug)]
419struct OrEndOfInputError;
420
421impl error::EndOfInputError for OrEndOfInputError {
422    fn doing(&self) -> String {
423        "skipping cases in an `ifcase` command".into()
424    }
425    fn notes(&self) -> Vec<error::display::Note> {
426        vec![
427        "each `or` command must be terminated by a `fi` command".into(),
428        "this `or` corresponds to an `ifcase` command that evaluated to %d, and the input ended while skipping the remaining cases".into(),
429        "this is the `ifcase` command involved in the error:".into(),
430        "this is the `or` command involved in the error:".into(),
431        ]
432    }
433}
434
435/// Get the `\or` primitive.
436pub fn get_or<S: HasComponent<Component>>() -> command::BuiltIn<S> {
437    command::BuiltIn::new_expansion(or_primitive_fn)
438        .with_tag(OR_TAG.get())
439        .with_doc(OR_DOC)
440}
441
442fn else_primitive_fn<S: HasComponent<Component>>(
443    else_token: token::Token,
444    input: &mut vm::ExpansionInput<S>,
445) -> txl::Result<()> {
446    let branch = pop_branch(input);
447    // For else token to be valid, we must be in the true branch of a conditional
448    let is_valid = match branch {
449        None => false,
450        Some(branch) => matches!(branch.kind, BranchKind::True | BranchKind::Switch),
451    };
452    if !is_valid {
453        input.error(error::SimpleTokenError::new(
454            else_token,
455            "unexpected `else` command",
456        ))?;
457        return Ok(());
458    }
459
460    // Now consume all of the tokens until the next \fi
461    let mut depth = 0;
462    loop {
463        let token = input.unexpanded().next_or_err(ElseEndOfInputError {})?;
464        if let token::Value::CommandRef(command_ref) = &token.value() {
465            let tag = input.commands_map().get_tag(command_ref);
466            if tag == Some(input.state().component().tags.if_tag) {
467                depth += 1;
468            }
469            if tag == Some(input.state().component().tags.fi_tag) {
470                depth -= 1;
471                if depth < 0 {
472                    return Ok(());
473                }
474            }
475        }
476    }
477}
478
479#[derive(Debug)]
480struct ElseEndOfInputError;
481
482impl error::EndOfInputError for ElseEndOfInputError {
483    fn doing(&self) -> String {
484        r"skipping the false branch of an conditional command".into()
485    }
486
487    fn notes(&self) -> Vec<error::display::Note> {
488        vec![
489            "each `else` command must be terminated by a `fi` command".into(),
490            "this `else` corresponds to an `if` command that evaluated to true, and the input ended while skipping the false branch".into(),
491            "this is the `if` command involved in the error:".into(),
492            "this is the `else` command involved in the error:".into(),
493        ]
494    }
495}
496
497/// Get the `\else` primitive.
498pub fn get_else<S: HasComponent<Component>>() -> command::BuiltIn<S> {
499    command::BuiltIn::new_expansion(else_primitive_fn)
500        .with_tag(ELSE_TAG.get())
501        .with_doc(ELSE_DOC)
502}
503
504/// Get the `\fi` primitive.
505fn fi_primitive_fn<S: HasComponent<Component>>(
506    token: token::Token,
507    input: &mut vm::ExpansionInput<S>,
508) -> txl::Result<()> {
509    let branch = pop_branch(input);
510    // For a \fi primitive to be valid, we must be in a conditional.
511    // Note that we could be in the false branch: \iftrue\else\fi
512    // Or in the true branch: \iftrue\fi
513    // Or in a switch statement.
514    if branch.is_none() {
515        input.error(error::SimpleTokenError::new(
516            token,
517            "unexpected `fi` command",
518        ))?;
519    }
520    Ok(())
521}
522
523/// Get the `\fi` primitive.
524pub fn get_fi<S: HasComponent<Component>>() -> command::BuiltIn<S> {
525    command::BuiltIn::new_expansion(fi_primitive_fn)
526        .with_tag(FI_TAG.get())
527        .with_doc(FI_DOC)
528}
529
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use texlang::vm::implement_has_component;
    use texlang_testing::*;

    #[derive(Default)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    struct State {
        conditional: Component,
        testing: TestingComponent,
    }

    impl TexlangState for State {
        fn recoverable_error_hook(
            &self,
            error: error::TracedTexError,
        ) -> Result<(), Box<dyn error::TexError>> {
            TestingComponent::recoverable_error_hook(self, error)
        }
    }

    implement_has_component![State {
        conditional: Component,
        testing: TestingComponent,
    }];

    fn built_in_commands() -> HashMap<&'static str, command::BuiltIn<State>> {
        HashMap::from([
            ("else", get_else()),
            ("fi", get_fi()),
            ("ifcase", get_ifcase()),
            ("iffalse", get_iffalse()),
            ("ifnum", get_ifnum()),
            ("ifodd", get_ifodd()),
            ("iftrue", get_iftrue()),
            ("or", get_or()),
        ])
    }

    test_suite![
        expansion_equality_tests(
            (iftrue_base_case, r"\iftrue a\else b\fi c", r"ac"),
            (iftrue_no_else, r"\iftrue a\fi c", r"ac"),
            (
                iftrue_skip_nested_ifs,
                r"\iftrue a\else b\iftrue \else c\fi d\fi e",
                r"ae"
            ),
            (iffalse_base_case, r"\iffalse a\else b\fi c", r"bc"),
            (iffalse_no_else, r"\iffalse a\fi c", r"c"),
            (
                iffalse_skip_nested_ifs,
                r"\iffalse \iftrue a\else b\fi c\else d\fi e",
                r"de"
            ),
            (
                iffalse_and_iftrue_1,
                r"\iffalse a\else b\iftrue c\else d\fi e\fi f",
                r"bcef"
            ),
            (
                iffalse_and_iftrue_2,
                r"\iftrue a\iffalse b\else c\fi d\else e\fi f",
                r"acdf"
            ),
            (ifnum_less_than_true, r"\ifnum 4<5a\else b\fi c", r"ac"),
            (ifnum_less_than_false, r"\ifnum 5<4a\else b\fi c", r"bc"),
            (ifnum_equal_true_1, r"\ifnum 4=4a\else b\fi c", r"ac"),
            // Multi-digit operands; this case used to be a verbatim copy of the one above.
            (ifnum_equal_true_2, r"\ifnum 12=12a\else b\fi c", r"ac"),
            (ifnum_equal_false, r"\ifnum 5=4a\else b\fi c", r"bc"),
            (ifnum_greater_than_true, r"\ifnum 5>4a\else b\fi c", r"ac"),
            (ifnum_greater_than_false, r"\ifnum 4>5a\else b\fi c", r"bc"),
            (ifodd_odd, r"\ifodd 3a\else b\fi c", r"ac"),
            (ifodd_even, r"\ifodd 4a\else b\fi c", r"bc"),
            (ifcase_zero_no_ors, r"\ifcase 0 a\else b\fi c", r"ac"),
            (ifcase_zero_one_or, r"\ifcase 0 a\or b\else c\fi d", r"ad"),
            (ifcase_one, r"\ifcase 1 a\or b\else c\fi d", r"bd"),
            (
                ifcase_one_more_cases,
                r"\ifcase 1 a\or b\or c\else d\fi e",
                r"be"
            ),
            (ifcase_else_no_ors, r"\ifcase 1 a\else b\fi c", r"bc"),
            (ifcase_else_one_or, r"\ifcase 2 a\or b\else c\fi d", r"cd"),
            (ifcase_no_matching_case, r"\ifcase 3 a\or b\or c\fi d", r"d"),
            (
                ifcase_nested,
                r"\ifcase 1 a\or b\ifcase 1 c\or d\or e\else f\fi g\or h\fi i",
                r"bdgi"
            ),
        ),
        serde_tests(
            (serde_if, r"\iftrue true ", r"branch \else false branch \fi"),
            (
                serde_ifcase,
                r"\ifcase 2 a\or b\or executed ",
                r"case \or d \fi"
            ),
        ),
        recoverable_failure_tests(
            (
                missing_op_1,
                r"\ifnum 3 3 equal \else not equal \fi",
                "equal",
            ),
            (
                missing_op_2,
                r"\ifnum 3 4 equal \else not equal \fi",
                "not equal",
            ),
            (else_not_expected, r"a\else b", "ab"),
            (fi_not_expected, r"a\fi b", "ab"),
            (or_not_expected, r"a\or b", "ab"),
        ),
        fatal_error_tests(
            (iftrue_end_of_input, r"\iftrue a\else b"),
            (iffalse_end_of_input, r"\iffalse a"),
        ),
    ];
}
652}