2572.无平方子集计数
链接:2572.无平方子集计数
难度:Medium
标签:位运算、数组、数学、动态规划、状态压缩
简介:返回数组 nums 中无平方且非空的子集数目(结果对 10^9 + 7 取模)。
题解 1 - rust
- 编辑时间:2023-02-19
- 执行用时:20ms
- 内存消耗:4.4MB
- 编程语言:rust
- 解法介绍:同题解 2(状态压缩 + dp)。
impl Solution {
    pub fn square_free_subsets(nums: Vec<i32>) -> i32 {
        let MOD = 1000000000 + 7;
        let MAXK = 10;
        let prime = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29];
        let check = |num: i32| -> bool {
            for v in prime.iter() {
                if num % (*v as i32).pow(2) == 0 {
                    return true;
                }
            }
            false
        };
        let mut nums = nums
            .into_iter()
            .filter(|num| !check(*num))
            .collect::<Vec<i32>>();
        let n = nums.len();
        let mut dp = vec![vec![0; 1 << MAXK]; n + 1];
        dp[0][0] = 1;
        for i in 1..=n {
            let num = nums[i - 1];
            let mut mask = 0;
            for j in 0..(1 << MAXK) {
                dp[i][j] = dp[i - 1][j];
            }
            for i in 0..MAXK {
                if num % prime[i] == 0 {
                    mask |= 1 << i;
                }
            }
            for j in 0..(1 << MAXK) {
                if (mask & j) == 0 {
                    dp[i][mask | j] = (dp[i][mask | j] + dp[i - 1][j]) % MOD;
                }
            }
        }
        let mut ans = 0;
        for j in 0..(1 << MAXK) {
            ans = (ans + dp[n][j]) % MOD;
        }
        ans = (ans - 1 + MOD) % MOD;
        ans
    }
}
题解 2 - cpp
- 编辑时间:2023-02-19
- 执行用时:156ms
- 内存消耗:90.5MB
- 编程语言:cpp
- 解法介绍:状态压缩 + dp。用 2~29 这 10 个质数组成的 10 位掩码表示乘积中出现过的质因子;先过滤掉能被某个质数平方整除的数,再对每个数,把它的质因子掩码与之前所有与之不相交(不重合)的掩码状态合并转移。
class Solution {
    typedef long long ll;
    const int mod = 1000000007;
    const int MAXK = 10;
    int prime[10] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};
    // True when num is divisible by the square of one of the primes above, so it
    // can never belong to a square-free subset. Integer multiplication is used
    // because floating-point pow() is not guaranteed to return an exact integer.
    bool check(int num) {
        for (int i = 0; i < MAXK; i++) {
            if (num % (prime[i] * prime[i]) == 0) return true;
        }
        return false;
    }
    // Bitmask whose bit k is set iff prime[k] divides num.
    int primeMask(int num) {
        int mask = 0;
        for (int k = 0; k < MAXK; k++) {
            if (num % prime[k] == 0) mask |= 1 << k;
        }
        return mask;
    }
public:
    // Counts non-empty square-free subsets of nums, modulo 1e9 + 7.
    // dp[s] = number of subsets (including the empty one) of the values processed
    // so far whose combined prime mask is exactly s. The caller's vector is not
    // modified.
    int squareFreeSubsets(vector<int>& nums) {
        vector<int> candidates = filter(nums);
        const int states = 1 << MAXK;
        vector<ll> dp(states, 0);
        dp[0] = 1;
        for (int num : candidates) {
            int mask = primeMask(num);
            // Descending order keeps dp[s] at its previous-round value when read:
            // the target s | mask is > s unless mask == 0 (num == 1), in which
            // case the in-place update correctly doubles dp[s].
            for (int s = states - 1; s >= 0; s--) {
                if ((s & mask) == 0) dp[s | mask] = (dp[s | mask] + dp[s]) % mod;
            }
        }
        ll ans = 0;
        for (ll v : dp) ans = (ans + v) % mod;
        // Remove the empty subset counted in dp[0].
        return (int)((ans - 1 + mod) % mod);
    }
    // Returns the values of nums that are not divisible by any prime square.
    vector<int> filter(const vector<int> &nums) {
        vector<int> res;
        for (int num : nums) {
            if (!check(num)) res.push_back(num);
        }
        return res;
    }
};
题解 3 - python
- 编辑时间:2023-02-19
- 执行用时:1628ms
- 内存消耗:28.6MB
- 编程语言:python
- 解法介绍:同上。
class Solution:
    def squareFreeSubsets(self, nums: List[int]) -> int:
        """Count the non-empty square-free subsets of ``nums`` modulo 10**9 + 7.

        Each value is encoded as the bitmask of which primes in ``primes``
        divide it. A subset is square-free iff every element is not divisible
        by any ``p * p`` and the elements' masks are pairwise disjoint.

        ``dp[s]`` counts subsets (including the empty one) of the values seen so
        far whose combined prime mask is exactly ``s``.

        Returns:
            The number of non-empty square-free subsets, as an ``int``.
        """
        # Integer literal: 1e9 + 7 would be a float and turn every dp entry into a float.
        mod = 10**9 + 7
        primes = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29)
        states = 1 << len(primes)

        dp = [0] * states
        dp[0] = 1
        for num in nums:
            # Values divisible by a prime square can never appear in the answer.
            if any(num % (p * p) == 0 for p in primes):
                continue
            mask = 0
            for k, p in enumerate(primes):
                if num % p == 0:
                    mask |= 1 << k
            # In-place update in descending order: dp[s] is read before any write
            # can reach it, because the target s | mask is > s unless mask == 0
            # (num == 1), where the update correctly doubles dp[s].
            for s in range(states - 1, -1, -1):
                if s & mask == 0:
                    dp[s | mask] = (dp[s | mask] + dp[s]) % mod
        # Drop the empty subset counted in dp[0].
        return (sum(dp) - 1) % mod