1use std::cell::RefCell;
9use texlang::parse::Ordering;
10use texlang::prelude as txl;
11use texlang::traits::*;
12use texlang::*;
13
// Documentation strings attached to each built-in command via `with_doc`.
const ELSE_DOC: &str = "Start the else branch of a conditional or switch statement";
const IFCASE_DOC: &str = "Begin a switch statement";
// `\ifnum` and `\ifodd` parse integers (`i32`), not variables.
const IFNUM_DOC: &str = "Compare two integers";
const IFODD_DOC: &str = "Check if an integer is odd";
const IFTRUE_DOC: &str = "Evaluate the true branch";
const IFFALSE_DOC: &str = "Evaluate the false branch";
const FI_DOC: &str = "End a conditional or switch statement";
const OR_DOC: &str = "Begin the next branch of a switch statement";
22
/// Component holding the state of the conditional primitives (`\if...`, `\ifcase`,
/// `\or`, `\else`, `\fi`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Default)]
pub struct Component {
    // Stack of currently open conditional branches; the innermost branch is the last
    // element (branches are pushed/popped at the end of the `Vec`). Wrapped in a
    // `RefCell` because `push_branch`/`pop_branch` mutate it through `input.state()`.
    #[cfg_attr(
        feature = "serde",
        serde(
            serialize_with = "serialize_branches",
            deserialize_with = "deserialize_branches"
        )
    )]
    branches: RefCell<Vec<Branch>>,

    // Tags of the conditional commands, used to recognise them while skipping tokens.
    // Skipped by serde, so on deserialization it is rebuilt with `Default::default()`,
    // which re-reads the static tags.
    #[cfg_attr(feature = "serde", serde(skip))]
    tags: Tags,
}
52
/// Serializes the branch stack as a plain sequence, hiding the `RefCell` wrapper.
#[cfg(feature = "serde")]
fn serialize_branches<S>(input: &RefCell<Vec<Branch>>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    use serde::Serialize;
    let branches = input.borrow();
    branches.as_slice().serialize(serializer)
}
62
/// Deserializes a plain sequence of branches and wraps it back into a `RefCell`.
#[cfg(feature = "serde")]
fn deserialize_branches<'de, D>(deserializer: D) -> Result<RefCell<Vec<Branch>>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::Deserialize;
    Vec::<Branch>::deserialize(deserializer).map(RefCell::new)
}
72
/// Tags of the conditional primitives, cached in the component so the token-skipping
/// loops can compare a command's tag against them.
struct Tags {
    // Shared by every `\if...` command, including `\ifcase`; used to track nesting.
    if_tag: command::Tag,
    else_tag: command::Tag,
    or_tag: command::Tag,
    fi_tag: command::Tag,
}
79
80impl Default for Tags {
81 fn default() -> Self {
82 Self {
83 if_tag: IF_TAG.get(),
84 else_tag: ELSE_TAG.get(),
85 or_tag: OR_TAG.get(),
86 fi_tag: FI_TAG.get(),
87 }
88 }
89}
90
/// The kind of conditional branch currently being expanded.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum BranchKind {
    /// The true branch of an `\if...` command; may be ended by `\else` or `\fi`.
    True,
    /// An else branch, entered after skipping a false branch or all `\ifcase` cases;
    /// may only be ended by `\fi`.
    Else,
    /// The selected case of an `\ifcase` command; may be ended by `\or`, `\else` or `\fi`.
    Switch,
}
101
/// An open conditional branch on the component's branch stack.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Branch {
    // Token of the command that opened the conditional. Not read in this file (hence
    // the underscore); presumably kept for future error reporting — TODO confirm.
    _token: token::Token,
    kind: BranchKind,
}
108
109fn push_branch<S: HasComponent<Component>>(input: &mut vm::ExpansionInput<S>, branch: Branch) {
110 input.state().component().branches.borrow_mut().push(branch)
111}
112
113fn pop_branch<S: HasComponent<Component>>(input: &mut vm::ExpansionInput<S>) -> Option<Branch> {
114 input.state().component().branches.borrow_mut().pop()
115}
116
// Tags identifying the conditional primitives. The skipping loops look commands up by
// tag (via `commands_map().get_tag`), and every `\if...` command, including `\ifcase`,
// carries `IF_TAG` so nesting can be tracked.
static IF_TAG: command::StaticTag = command::StaticTag::new();
static ELSE_TAG: command::StaticTag = command::StaticTag::new();
static OR_TAG: command::StaticTag = command::StaticTag::new();
static FI_TAG: command::StaticTag = command::StaticTag::new();
121
122fn true_case<S: HasComponent<Component>>(
124 token: token::Token,
125 input: &mut vm::ExpansionInput<S>,
126) -> txl::Result<()> {
127 push_branch(
128 input,
129 Branch {
130 _token: token,
131 kind: BranchKind::True,
132 },
133 );
134 Ok(())
135}
136
137fn false_case<S: HasComponent<Component>>(
142 original_token: token::Token,
143 input: &mut vm::ExpansionInput<S>,
144) -> txl::Result<()> {
145 let mut depth = 0;
146 loop {
147 let token = input
148 .unexpanded()
149 .next_or_err(FalseBranchEndOfInputError {})?;
150 if let token::Value::CommandRef(command_ref) = &token.value() {
151 let tag = input.commands_map().get_tag(command_ref);
152 if tag == Some(input.state().component().tags.else_tag) && depth == 0 {
153 push_branch(
154 input,
155 Branch {
156 _token: original_token,
157 kind: BranchKind::Else,
158 },
159 );
160 return Ok(());
161 }
162 if tag == Some(input.state().component().tags.if_tag) {
163 depth += 1;
164 }
165 if tag == Some(input.state().component().tags.fi_tag) {
166 depth -= 1;
167 if depth < 0 {
168 return Ok(());
169 }
170 }
171 }
172 }
173}
174
175#[derive(Debug)]
176struct FalseBranchEndOfInputError;
177
178impl error::EndOfInputError for FalseBranchEndOfInputError {
179 fn doing(&self) -> String {
180 r"skipping the true branch of an conditional command".into()
181 }
182 fn notes(&self) -> Vec<error::display::Note> {
183 vec![
184 "each `if` command must be terminated by a `fi` command, with an optional `else` in between".into(),
185 "this `if` command evaluated to false, and the input ended while skipping the true branch".into(),
186 ]
187 }
188}
189
190pub trait Condition<S: HasComponent<Component>> {
198 const DOC: Option<&'static str> = None;
200
201 fn evaluate(input: &mut vm::ExpansionInput<S>) -> txl::Result<bool>;
206
207 fn build_if_command() -> command::BuiltIn<S> {
209 let primitive_fn =
210 |token: token::Token, input: &mut vm::ExpansionInput<S>| -> txl::Result<()> {
211 match Self::evaluate(input)? {
212 true => true_case(token, input),
213 false => false_case(token, input),
214 }
215 };
216 let mut cmd = command::BuiltIn::new_expansion(primitive_fn).with_tag(IF_TAG.get());
217 if let Some(doc) = Self::DOC {
218 cmd = cmd.with_doc(doc)
219 };
220 cmd
221 }
222}
223
224struct IfTrue;
225
226impl<S: HasComponent<Component>> Condition<S> for IfTrue {
227 const DOC: Option<&'static str> = Some(IFTRUE_DOC);
228
229 fn evaluate(_: &mut vm::ExpansionInput<S>) -> txl::Result<bool> {
230 Ok(true)
231 }
232}
233
234pub fn get_iftrue<S: HasComponent<Component>>() -> command::BuiltIn<S> {
236 IfTrue::build_if_command()
237}
238
239struct IfFalse;
240
241impl<S: HasComponent<Component>> Condition<S> for IfFalse {
242 const DOC: Option<&'static str> = Some(IFFALSE_DOC);
243
244 fn evaluate(_: &mut vm::ExpansionInput<S>) -> txl::Result<bool> {
245 Ok(false)
246 }
247}
248
249pub fn get_iffalse<S: HasComponent<Component>>() -> command::BuiltIn<S> {
251 IfFalse::build_if_command()
252}
253
254struct IfNum;
255
256impl<S: HasComponent<Component>> Condition<S> for IfNum {
257 const DOC: Option<&'static str> = Some(IFNUM_DOC);
258
259 fn evaluate(input: &mut vm::ExpansionInput<S>) -> txl::Result<bool> {
260 let (a, o, b) = <(i32, Ordering, i32)>::parse(input)?;
261 Ok(a.cmp(&b) == o.0)
262 }
263}
264
265pub fn get_ifnum<S: HasComponent<Component>>() -> command::BuiltIn<S> {
267 IfNum::build_if_command()
268}
269
270struct IfOdd;
271
272impl<S: HasComponent<Component>> Condition<S> for IfOdd {
273 const DOC: Option<&'static str> = Some(IFODD_DOC);
274
275 fn evaluate(input: &mut vm::ExpansionInput<S>) -> txl::Result<bool> {
276 let n = i32::parse(input)?;
277 Ok((n % 2) == 1)
278 }
279}
280
281pub fn get_ifodd<S: HasComponent<Component>>() -> command::BuiltIn<S> {
283 IfOdd::build_if_command()
284}
285
286fn if_case_primitive_fn<S: HasComponent<Component>>(
287 ifcase_token: token::Token,
288 input: &mut vm::ExpansionInput<S>,
289) -> txl::Result<()> {
290 let total_cases_to_skip = i32::parse(input)?;
292 if total_cases_to_skip == 0 {
293 push_branch(
294 input,
295 Branch {
296 _token: ifcase_token,
297 kind: BranchKind::Switch,
298 },
299 );
300 return Ok(());
301 }
302 let mut cases_left_to_skip = total_cases_to_skip;
303 let mut depth = 0;
304 loop {
305 let token = input.unexpanded().next_or_err(IfCaseEndOfInputError {
306 total_cases_to_skip,
307 cases_left_to_skip,
308 })?;
309 if let token::Value::CommandRef(command_ref) = &token.value() {
310 let tag = input.commands_map().get_tag(command_ref);
311 if tag == Some(input.state().component().tags.or_tag) && depth == 0 {
312 cases_left_to_skip -= 1;
313 if cases_left_to_skip == 0 {
314 push_branch(
315 input,
316 Branch {
317 _token: ifcase_token,
318 kind: BranchKind::Switch,
319 },
320 );
321 return Ok(());
322 }
323 }
324 if tag == Some(input.state().component().tags.else_tag) && depth == 0 {
325 push_branch(
326 input,
327 Branch {
328 _token: ifcase_token,
329 kind: BranchKind::Else,
330 },
331 );
332 return Ok(());
333 }
334 if tag == Some(input.state().component().tags.if_tag) {
335 depth += 1;
336 }
337 if tag == Some(input.state().component().tags.fi_tag) {
338 depth -= 1;
339 if depth < 0 {
340 return Ok(());
341 }
342 }
343 }
344 }
345}
346
347#[derive(Debug)]
348struct IfCaseEndOfInputError {
349 total_cases_to_skip: i32,
350 cases_left_to_skip: i32,
351}
352
353impl error::EndOfInputError for IfCaseEndOfInputError {
354 fn doing(&self) -> String {
355 "skipping cases in an `ifcase` command".into()
356 }
357
358 fn notes(&self) -> Vec<error::display::Note> {
359 vec![
360 "each `ifcase` command must be matched by a `or`, `else` or `fi` command".into(),
361 format![
362 "this `ifcase` case evaluated to {}",
363 self.total_cases_to_skip
364 ]
365 .into(),
366 format![
367 "the input ended while skipping case {}",
368 self.total_cases_to_skip + 1 - self.cases_left_to_skip
369 ]
370 .into(),
371 ]
372 }
373}
374
375pub fn get_ifcase<S: HasComponent<Component>>() -> command::BuiltIn<S> {
377 command::BuiltIn::new_expansion(if_case_primitive_fn)
378 .with_tag(IF_TAG.get())
379 .with_doc(IFCASE_DOC)
380}
381
382fn or_primitive_fn<S: HasComponent<Component>>(
383 ifcase_token: token::Token,
384 input: &mut vm::ExpansionInput<S>,
385) -> txl::Result<()> {
386 let branch = pop_branch(input);
387 let is_valid = match branch {
389 None => false,
390 Some(branch) => matches!(branch.kind, BranchKind::Switch),
391 };
392 if !is_valid {
393 input.error(error::SimpleTokenError::new(
394 ifcase_token,
395 "unexpected `or` command",
396 ))?;
397 return Ok(());
398 }
399
400 let mut depth = 0;
401 loop {
402 let token = input.unexpanded().next_or_err(OrEndOfInputError {})?;
403 if let token::Value::CommandRef(command_ref) = &token.value() {
404 let tag = input.commands_map().get_tag(command_ref);
405 if tag == Some(input.state().component().tags.if_tag) {
406 depth += 1;
407 }
408 if tag == Some(input.state().component().tags.fi_tag) {
409 depth -= 1;
410 if depth < 0 {
411 return Ok(());
412 }
413 }
414 }
415 }
416}
417
418#[derive(Debug)]
419struct OrEndOfInputError;
420
421impl error::EndOfInputError for OrEndOfInputError {
422 fn doing(&self) -> String {
423 "skipping cases in an `ifcase` command".into()
424 }
425 fn notes(&self) -> Vec<error::display::Note> {
426 vec![
427 "each `or` command must be terminated by a `fi` command".into(),
428 "this `or` corresponds to an `ifcase` command that evaluated to %d, and the input ended while skipping the remaining cases".into(),
429 "this is the `ifcase` command involved in the error:".into(),
430 "this is the `or` command involved in the error:".into(),
431 ]
432 }
433}
434
435pub fn get_or<S: HasComponent<Component>>() -> command::BuiltIn<S> {
437 command::BuiltIn::new_expansion(or_primitive_fn)
438 .with_tag(OR_TAG.get())
439 .with_doc(OR_DOC)
440}
441
442fn else_primitive_fn<S: HasComponent<Component>>(
443 else_token: token::Token,
444 input: &mut vm::ExpansionInput<S>,
445) -> txl::Result<()> {
446 let branch = pop_branch(input);
447 let is_valid = match branch {
449 None => false,
450 Some(branch) => matches!(branch.kind, BranchKind::True | BranchKind::Switch),
451 };
452 if !is_valid {
453 input.error(error::SimpleTokenError::new(
454 else_token,
455 "unexpected `else` command",
456 ))?;
457 return Ok(());
458 }
459
460 let mut depth = 0;
462 loop {
463 let token = input.unexpanded().next_or_err(ElseEndOfInputError {})?;
464 if let token::Value::CommandRef(command_ref) = &token.value() {
465 let tag = input.commands_map().get_tag(command_ref);
466 if tag == Some(input.state().component().tags.if_tag) {
467 depth += 1;
468 }
469 if tag == Some(input.state().component().tags.fi_tag) {
470 depth -= 1;
471 if depth < 0 {
472 return Ok(());
473 }
474 }
475 }
476 }
477}
478
479#[derive(Debug)]
480struct ElseEndOfInputError;
481
482impl error::EndOfInputError for ElseEndOfInputError {
483 fn doing(&self) -> String {
484 r"skipping the false branch of an conditional command".into()
485 }
486
487 fn notes(&self) -> Vec<error::display::Note> {
488 vec![
489 "each `else` command must be terminated by a `fi` command".into(),
490 "this `else` corresponds to an `if` command that evaluated to true, and the input ended while skipping the false branch".into(),
491 "this is the `if` command involved in the error:".into(),
492 "this is the `else` command involved in the error:".into(),
493 ]
494 }
495}
496
497pub fn get_else<S: HasComponent<Component>>() -> command::BuiltIn<S> {
499 command::BuiltIn::new_expansion(else_primitive_fn)
500 .with_tag(ELSE_TAG.get())
501 .with_doc(ELSE_DOC)
502}
503
504fn fi_primitive_fn<S: HasComponent<Component>>(
506 token: token::Token,
507 input: &mut vm::ExpansionInput<S>,
508) -> txl::Result<()> {
509 let branch = pop_branch(input);
510 if branch.is_none() {
515 input.error(error::SimpleTokenError::new(
516 token,
517 "unexpected `fi` command",
518 ))?;
519 }
520 Ok(())
521}
522
523pub fn get_fi<S: HasComponent<Component>>() -> command::BuiltIn<S> {
525 command::BuiltIn::new_expansion(fi_primitive_fn)
526 .with_tag(FI_TAG.get())
527 .with_doc(FI_DOC)
528}
529
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use texlang::vm::implement_has_component;
    use texlang_testing::*;

    // Minimal VM state containing only the conditional and testing components.
    #[derive(Default)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    struct State {
        conditional: Component,
        testing: TestingComponent,
    }

    impl TexlangState for State {
        fn recoverable_error_hook(
            &self,
            error: error::TracedTexError,
        ) -> Result<(), Box<dyn error::TexError>> {
            TestingComponent::recoverable_error_hook(self, error)
        }
    }

    implement_has_component![State {
        conditional: Component,
        testing: TestingComponent,
    }];

    fn built_in_commands() -> HashMap<&'static str, command::BuiltIn<State>> {
        HashMap::from([
            ("else", get_else()),
            ("fi", get_fi()),
            ("ifcase", get_ifcase()),
            ("iffalse", get_iffalse()),
            ("ifnum", get_ifnum()),
            ("ifodd", get_ifodd()),
            ("iftrue", get_iftrue()),
            ("or", get_or()),
        ])
    }

    test_suite![
        expansion_equality_tests(
            (iftrue_base_case, r"\iftrue a\else b\fi c", r"ac"),
            (iftrue_no_else, r"\iftrue a\fi c", r"ac"),
            (
                iftrue_skip_nested_ifs,
                r"\iftrue a\else b\iftrue \else c\fi d\fi e",
                r"ae"
            ),
            (iffalse_base_case, r"\iffalse a\else b\fi c", r"bc"),
            (iffalse_no_else, r"\iffalse a\fi c", r"c"),
            (
                iffalse_skip_nested_ifs,
                r"\iffalse \iftrue a\else b\fi c\else d\fi e",
                r"de"
            ),
            (
                iffalse_and_iftrue_1,
                r"\iffalse a\else b\iftrue c\else d\fi e\fi f",
                r"bcef"
            ),
            (
                iffalse_and_iftrue_2,
                r"\iftrue a\iffalse b\else c\fi d\else e\fi f",
                r"acdf"
            ),
            (ifnum_less_than_true, r"\ifnum 4<5a\else b\fi c", r"ac"),
            (ifnum_less_than_false, r"\ifnum 5<4a\else b\fi c", r"bc"),
            (ifnum_equal_true_1, r"\ifnum 4=4a\else b\fi c", r"ac"),
            // Previously an exact duplicate of `ifnum_equal_true_1`; this variant also
            // checks that the optional space after the second integer is consumed.
            (ifnum_equal_true_2, r"\ifnum 4=4 a\else b\fi c", r"ac"),
            (ifnum_equal_false, r"\ifnum 5=4a\else b\fi c", r"bc"),
            (ifnum_greater_than_true, r"\ifnum 5>4a\else b\fi c", r"ac"),
            (ifnum_greater_than_false, r"\ifnum 4>5a\else b\fi c", r"bc"),
            (ifodd_odd, r"\ifodd 3a\else b\fi c", r"ac"),
            (ifodd_even, r"\ifodd 4a\else b\fi c", r"bc"),
            (ifcase_zero_no_ors, r"\ifcase 0 a\else b\fi c", r"ac"),
            (ifcase_zero_one_or, r"\ifcase 0 a\or b\else c\fi d", r"ad"),
            (ifcase_one, r"\ifcase 1 a\or b\else c\fi d", r"bd"),
            (
                ifcase_one_more_cases,
                r"\ifcase 1 a\or b\or c\else d\fi e",
                r"be"
            ),
            (ifcase_else_no_ors, r"\ifcase 1 a\else b\fi c", r"bc"),
            (ifcase_else_one_or, r"\ifcase 2 a\or b\else c\fi d", r"cd"),
            (ifcase_no_matching_case, r"\ifcase 3 a\or b\or c\fi d", r"d"),
            (
                ifcase_nested,
                r"\ifcase 1 a\or b\ifcase 1 c\or d\or e\else f\fi g\or h\fi i",
                r"bdgi"
            ),
        ),
        serde_tests(
            (serde_if, r"\iftrue true ", r"branch \else false branch \fi"),
            (
                serde_ifcase,
                r"\ifcase 2 a\or b\or executed ",
                r"case \or d \fi"
            ),
        ),
        recoverable_failure_tests(
            (
                missing_op_1,
                r"\ifnum 3 3 equal \else not equal \fi",
                "equal",
            ),
            (
                missing_op_2,
                r"\ifnum 3 4 equal \else not equal \fi",
                "not equal",
            ),
            (else_not_expected, r"a\else b", "ab"),
            (fi_not_expected, r"a\fi b", "ab"),
            (or_not_expected, r"a\or b", "ab"),
        ),
        fatal_error_tests(
            (iftrue_end_of_input, r"\iftrue a\else b"),
            (iffalse_end_of_input, r"\iffalse a"),
        ),
    ];
}
652}