Rust之美迭代器在算法中的应用

前言

在本文中我们将从一个简单的算法题引出Rust语言中迭代器相关的几个方法,帮助大家理解chars,all ,try_fold,enumerate,try_for_each,char_indices的用法。 ​

题目如下

我们定义,在以下情况时,单词的大写用法是正确的:
引用
全部字母都是大写,比如 "USA" 。
单词中所有字母都不是大写,比如 "leetcode" 。
如果单词不只含有一个字母,只有首字母大写, 比如 "Google" 。
给你一个字符串 word 。如果大写用法正确,返回 true ;否则,返回 false 。
引用
来源:力扣(LeetCode)
链接: https://leetcode-cn.com/probl...
著作权归领扣网络所有。商业转载请联系官方授权,非商业转载请注明出处。

分析
根据题意,满足条件的字符串就是

  • 要么字符串中的字符都是小写;
  • 要么都是大写;
  • 要么第一个是大写,剩下的是小写。

原始的解法
根据上面的分析,我们只需要做下面三个判断就可以了

  • word中字符如果都是大写,返回true
  • word中字符是否都是小写,返回true
  • word中首字符大写且剩余字符小写,返回true
  • 其它情况返回false
pub fn detect_capital_use(word: String) -> bool {
    if word.len() < 2 {
        return true;
    }

    // 判断都是大写
    let mut res = true;
    for c in word.as_bytes() {
        if c.is_ascii_lowercase() {
            res = false;
            break;
        }
    }
    if res {
        return res;
    }

    // 判断都是小写
    let mut res = true;
    for c in word.as_bytes() {
        if c.is_ascii_uppercase() {
            res = false;
            break;
        }
    }
    if res {
        return res;
    }

    // 判断首字母大写,剩余小写
    if word.as_bytes()[0].is_ascii_lowercase() {
        return false;
    }
    let mut res = true;
    for c in &word.as_bytes()[1..] {
        if c.is_ascii_uppercase() {
            res = false;
            break;
        }
    }
    if res {
        return res;
    }

    false
}

使用迭代器
上面的代码中使用了三次遍历,如果使用Rust中的迭代器,可以非常简洁,对应代码如下:

pub fn detect_capital_use(word: String) -> bool {

    if word.len() ==0{
        return true
    }
    if word.len() ==1{
        return true
    }

    let mut word1 = word.chars(); // 返回word中字符的迭代器
    if word1.all(|x|x.is_lowercase()){ // 都是小写
        return true 
    }

    let mut word1 = word.chars();
    if word1.all(|x|x.is_uppercase()){ // 都是大写
        return true 
    }

    let mut word1 = word.chars();
    let first_word = word1.next().unwrap();// 获取第一个字符
    if first_word.is_lowercase(){  //首字符大写
        return false
    }
    if word1.all(|x|x.is_lowercase()){ // 剩下的小写
        return true 
    }

    false

}

代码分析
上面代码中我们使用.chars()方法得到String的字符迭代器,然后利用了rust中迭代器的all 方法完成了该功能,整个代码逻辑非常贴合人类的自然语言,可读性非常强。下面我们先介绍下 all 方法: 迭代器的all 方法用来判断迭代器所遍历的所有项是否满足闭包的条件

  1. Tests if every element of the iterator matches a predicate.
  2. all() takes a closure that returns true or false. It applies this closure to each element of the iterator, and if they all return true, then so does all(). If any of them return false, it returns false.
  3. all() is short-circuiting; in other words, it will stop processing as soon as it finds a false, given that no matter what else happens, the result will also be false.
  4. An empty iterator returns true.

翻译过来就是

  1. 判断迭代器中的每一个元素是否满足断言(也就是传入的闭包函数)。
  2. all() 接受一个闭包作为参数,这个闭包返回true或false。all() 在迭代器遍历过程中,把每一个元素传入闭包,如果所有的元素传入闭包中都返回true,那么all() 就返回true,否则all() 返回false。
  3. all() 是短路;换句话说,在遍历的过程中,一旦某个元素传入闭包后返回false,就会立刻停止遍历,无论后面元素的结果是什么,最终all() 的结果是 false
  4. 一个空的迭代器的all() 永远返回 true

all() 对应的源码如下:

#[inline]
    #[stable(feature = "rust1", since = "1.0.0")]
    fn all(&mut self, f: F) -> bool
    where
        Self: Sized,
        F: FnMut(Self::Item) -> bool,
    {
        #[inline]
        fn check(mut f: impl FnMut(T) -> bool) -> impl FnMut((), T) -> ControlFlow<()> {
            move |(), x| {
                if f(x) { ControlFlow::CONTINUE } else { ControlFlow::BREAK }
            }
        }
        self.try_fold((), check(f)) == ControlFlow::CONTINUE
    }

可以看到all方法内部起始是使用了try_fold() 来实现。上面我们提到all() 是短路的,就是使用了try_fold的特性。

  • 首先all 方法将传入的闭包 f 封装成impl FnMut((), T) -> ControlFlow<()>类型的新闭包,这个新闭包在f(x) 为true的时候返回ControlFlow::CONTINUE,在为false的时候返回ControlFlow::BREAK, 而try_fold() 在闭包函数返回ControlFlow::BREAK的时候就会提前退出,实现短路的性质。
  • try_fold() 如果最终返回ControlFlow::CONTINUE,表示所有的元素执行f(x)返回true,如果返回ControlFlow::BREAK表示中间遇到了一个元素执行f(x)返回false.

接下来我们再仔细看下try_fold这个方法: ​

  • An iterator method that applies a function as long as it returns successfully, producing a single, final value.
  • try_fold() takes two arguments: an initial value, and a closure with two arguments: an ‘accumulator’, and an element. The closure either returns successfully, with the value that the accumulator should have for the next iteration, or it returns failure, with an error value that is propagated back to the caller immediately (short-circuiting).
  • The initial value is the value the accumulator will have on the first call. If applying the closure succeeded against every element of the iterator, try_fold() returns the final accumulator as success.
  • Folding is useful whenever you have a collection of something, and want to produce a single value from it.

翻译过来就是

  • try_fold 这个方法在迭代器上应用一个函数,如果这个函数一致返回成功,就继续执行直到返回一个最终的值,如果失败就提前退出。
  • try_fold 接受两个参数,第一个参数是初始值,第二个参数是一个闭包,而这个闭包需要传入两个参数:一个累加值,一个元素值。这个闭包如果返回成功的话,就会返回下次运算所需的累加值,如果返回失败,就会立刻将错误的值直接返回到调用者中(短路)
  • 初始值是第一次闭包调用时候使用的累加值。如果迭代器中每一个元素应用到闭包上后就成功,那么try_fold()执行结束后就会返回最总的累加值。
  • 当你有一个元素的机制,需要从这个集合中得到单个值的时候,Fold就会很有用。

try_fold的源码如下,也比较简单

#[inline]
    #[stable(feature = "iterator_try_fold", since = "1.27.0")]
    fn try_fold(&mut self, init: B, mut f: F) -> R
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> R,
        R: Try,
    {
        let mut accum = init;
        while let Some(x) = self.next() {
            accum = f(accum, x)?;
        }
        try { accum }
    }

可以看到try_fold内部通过迭代器的next方法不断获取迭代器中的每一个元素,执行函数f,accum = f(accum, x)?;,其中 ? 是rust中的一个特殊符号,表示如果f(accum,x) 返回ControlFlow::BREAK就退出当前循环并且将ControlFlow::BREAK作为 try_fold 的返回值。

简化全小写,全大写判断

上面代码中,判断word是否都是小写的逻辑判断,分为了两行,先通过chars获取迭代器,然后调用all方法判断是否都是小写,这块代码可以直接用一行代码替代 word.chars().all(|x|x.is_lowercase()), 判断word都是大写的判断同理,如下:

pub fn detect_capital_use(word: String) -> bool {

    if word.len() ==0{
        return true
    }
    if word.len() ==1{
        return true
    }

    if word.chars().all(|x|x.is_lowercase()){ // 如果都是小写
        return true 
    }

    if word.chars().all(|x|x.is_uppercase()){ // 如果都是大写
        return true 
    }

    let mut word1 = word.chars();
    let first_word = word1.next().unwrap(); // 获取第一个字符
    if first_word.is_lowercase(){  //如果第一个字符是小写,返回false
        return false
    }
    if word1.all(|x|x.is_lowercase()){  // 剩下的都是小写,返回true
        return true 
    }

    false

}

简化首字符大写,其它小写的判断

由于.chars得到的迭代器在遍历中只会返回每一个元素,并不会返回在原先集合中的索引值,所以我们先调用next获取第一个字符判断是否是大写,然后再用all判断剩下的字符是否都是小写。 如果我们可以同时获取索引和元素值,那么就可以在all的闭包中同时判断首字符和其它字符了。为了获取包含索引的迭代器,我们可以采用enumerate对迭代器进行封装,代码如下:

pub fn detect_capital_use(word: String) -> bool {
    if word.len() == 0 {
        return true;
    }
    if word.len() == 1 {
        return true;
    }

    if word.chars().all(|x| x.is_lowercase()) {// 如果都是小写
        return true;
    }

    if word.chars().all(|x| x.is_uppercase()) {// 如果都是大写
        return true;
    }

    if word.chars().enumerate().all(|(id, x)| {
        if id == 0 {
            x.is_uppercase()  // 第一个字符是大写
        } else {
            x.is_lowercase()  // 剩下的都是小写
        }
    }) {
        return true;
    }

    false
}

可以看到enumerate()每次迭代返回的元素是一个元组,元组的第一个元素就是索引值,我们就可以根据索引值和元素值判断了。enumerate()的方法返回的Enumerate类型的新的迭代器,对应源码如下:

pub struct Enumerate {
    iter: I,
    count: usize,
}
impl Enumerate {
    pub(in crate::iter) fn new(iter: I) -> Enumerate {
        Enumerate { iter, count: 0 }
    }
}

#[stable(feature = "rust1", since = "1.0.0")]
impl Iterator for Enumerate
where
    I: Iterator,
{
    type Item = (usize, ::Item);

    /// # Overflow Behavior
    ///
    /// The method does no guarding against overflows, so enumerating more than
    /// `usize::MAX` elements either produces the wrong result or panics. If
    /// debug assertions are enabled, a panic is guaranteed.
    ///
    /// # Panics
    ///
    /// Might panic if the index of the element overflows a `usize`.
    #[inline]
    #[rustc_inherit_overflow_checks]
    fn next(&mut self) -> Option<(usize, ::Item)> {
        let a = self.iter.next()?;
        let i = self.count;
        self.count += 1;
        Some((i, a))
    }

    ...
}

可以看到 Enumerate主要是将之前的迭代器进行了封装,然后内部维护了一个count来记录索引,在每次next的时候更新并作为元组的第一个元素返回。 将代码中三个逻辑判断合并得到:

pub fn detect_capital_use(word: String) -> bool {
    if word.len() == 0 {
        return true;
    }
    if word.len() == 1 {
        return true;
    }

    if word.chars().all(|x| x.is_lowercase()) // 如果都是小写
        || word.chars().all(|x| x.is_uppercase())// 如果都是小写
        || word.chars().enumerate().all(|(id, x)| {
            if id == 0 {
                x.is_uppercase() // 第一个字符是大写
            } else {
                x.is_lowercase() // 第二个字符是小写
            }
        })
    {
        return true;
    }

    false
}

由于空的迭代器的all永远返回true,上面的word长度的判断也可以省略,得到:

pub fn detect_capital_use(word: String) -> bool {

    if word.chars().all(|x| x.is_lowercase()) // 
        || word.chars().all(|x| x.is_uppercase())
        || word.chars().enumerate().all(|(id, x)| {
            if id == 0 {
                x.is_uppercase()
            } else {
                x.is_lowercase()
            }
        })
    {
        return true;
    }

    false
}

另外 word.chars().enumerate()也可以直接简写为 word.char_indices(),所以上面代码进一步可写为

pub fn detect_capital_use(word: String) -> bool {

    if word.chars().all(|x| x.is_lowercase()) // 
        || word.chars().all(|x| x.is_uppercase())
        || word.char_indices().all(|(id, x)| {
            if id == 0 {
                x.is_uppercase()
            } else {
                x.is_lowercase()
            }
        })
    {
        return true;
    }

    false
}

三次遍历变为一次遍历
上面的算法对于首字母大写的情况会导致三次遍历,进一步可以优化

  • 判断除首字母之外剩下的字母是否都是小写或都是大写,具体实现:判断和第二个字母的大小写是否一致
  • 如果剩下的大小写不一样,直接返回false
  • 然后再根据第一个字母,第二个字母判断
  • 如果第一个字母是大写,返回true 。对于都是大写,或首字母大写
  • 如果如果第一个字母和第二个字母都是小写,返回true。对应都是小写的请。
  • 返回false

代码如下:

pub fn detect_capital_use(word: String) -> bool {

    let mut word = word.chars();
    let first = word.next();
    if first.is_none(){
        return true
    }
    let first = first.unwrap();

    if let Some(second) = word.next(){
        let res = word.try_for_each(move |x|{
            if second.is_lowercase() && x.is_lowercase(){
                return Ok(())
            }

            if second.is_uppercase() && x.is_uppercase(){
                return Ok(())
            }

            Err(())
        });

        if res.is_err(){
            return false
        }
        if first.is_uppercase(){
            return true
        }

        if first.is_lowercase() && second.is_lowercase(){
            return true
        }

        false
    }else{
        true
    }

}

这里代码中使用了try_for_each,官方描述:
An iterator method that applies a fallible function to each item in the iterator, stopping at the first error and returning that error. 翻译过来就是对迭代器中的每一个元素执行一个可能会出错的函数,如果函数返回错误就立刻停止迭代并返回该错误。 其中 try_for_each的源码如下,可以看到内部也是调用了try_fold实现

fn try_for_each(&mut self, f: F) -> R
    where
        Self: Sized,
        F: FnMut(Self::Item) -> R,
        R: Try,
    {
        #[inline]
        fn call(mut f: impl FnMut(T) -> R) -> impl FnMut((), T) -> R {
            move |(), x| f(x)
        }

        self.try_fold((), call(f))
    }

所以上面代码也可以把try_for_each直接使用 try_fold来写,如下:

pub fn detect_capital_use(word: String) -> bool {
    let mut word = word.chars();
    let first = word.next();
    if first.is_none() {
        return true;
    }
    let first = first.unwrap();

    if let Some(second) = word.next() {
        let res = word.try_fold(second, move |sd, x| {
            if sd.is_lowercase() && x.is_lowercase() {
                return Ok(sd);
            }

            if sd.is_uppercase() && x.is_uppercase() {
                return Ok(sd);
            }

            Err(())
        });

        if res.is_err() {
            return false;
        }
        if first.is_uppercase() {
            return true;
        }

        if first.is_lowercase() && second.is_lowercase() {
            return true;
        }

        false
    } else {
        true
    }
}

总结
在本文中可以看到虽然是一个简单的算法例子,利用Rust语言中迭代器以及相关的几个方法chars,all,try_fold,enumerate,try_for_each,char_indices`,可以看到Rust语言的灵活性和强大。