Merge tag 'ext4_for_linus_stable' of git://git.kernel.org/pub/scm/linux/kernel/git...
[firefly-linux-kernel-4.4.55.git] / arch / x86 / kvm / mmu.c
index 44a7d25154973437e0ce4233e142d01c43a948b9..f807496b62c2cc76e82a60cd58ee187f0cdc77c2 100644 (file)
@@ -223,15 +223,15 @@ static unsigned int get_mmio_spte_generation(u64 spte)
        return gen;
 }
 
-static unsigned int kvm_current_mmio_generation(struct kvm *kvm)
+static unsigned int kvm_current_mmio_generation(struct kvm_vcpu *vcpu)
 {
-       return kvm_memslots(kvm)->generation & MMIO_GEN_MASK;
+       return kvm_vcpu_memslots(vcpu)->generation & MMIO_GEN_MASK;
 }
 
-static void mark_mmio_spte(struct kvm *kvm, u64 *sptep, u64 gfn,
+static void mark_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, u64 gfn,
                           unsigned access)
 {
-       unsigned int gen = kvm_current_mmio_generation(kvm);
+       unsigned int gen = kvm_current_mmio_generation(vcpu);
        u64 mask = generation_mmio_spte_mask(gen);
 
        access &= ACC_WRITE_MASK | ACC_USER_MASK;
@@ -258,22 +258,22 @@ static unsigned get_mmio_spte_access(u64 spte)
        return (spte & ~mask) & ~PAGE_MASK;
 }
 
-static bool set_mmio_spte(struct kvm *kvm, u64 *sptep, gfn_t gfn,
+static bool set_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, gfn_t gfn,
                          pfn_t pfn, unsigned access)
 {
        if (unlikely(is_noslot_pfn(pfn))) {
-               mark_mmio_spte(kvm, sptep, gfn, access);
+               mark_mmio_spte(vcpu, sptep, gfn, access);
                return true;
        }
 
        return false;
 }
 
-static bool check_mmio_spte(struct kvm *kvm, u64 spte)
+static bool check_mmio_spte(struct kvm_vcpu *vcpu, u64 spte)
 {
        unsigned int kvm_gen, spte_gen;
 
-       kvm_gen = kvm_current_mmio_generation(kvm);
+       kvm_gen = kvm_current_mmio_generation(vcpu);
        spte_gen = get_mmio_spte_generation(spte);
 
        trace_check_mmio_spte(spte, kvm_gen, spte_gen);
@@ -804,30 +804,36 @@ static struct kvm_lpage_info *lpage_info_slot(gfn_t gfn,
        return &slot->arch.lpage_info[level - 2][idx];
 }
 
-static void account_shadowed(struct kvm *kvm, gfn_t gfn)
+static void account_shadowed(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
+       struct kvm_memslots *slots;
        struct kvm_memory_slot *slot;
        struct kvm_lpage_info *linfo;
+       gfn_t gfn;
        int i;
 
-       slot = gfn_to_memslot(kvm, gfn);
-       for (i = PT_DIRECTORY_LEVEL;
-            i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
+       gfn = sp->gfn;
+       slots = kvm_memslots_for_spte_role(kvm, sp->role);
+       slot = __gfn_to_memslot(slots, gfn);
+       for (i = PT_DIRECTORY_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
                linfo = lpage_info_slot(gfn, slot, i);
                linfo->write_count += 1;
        }
        kvm->arch.indirect_shadow_pages++;
 }
 
-static void unaccount_shadowed(struct kvm *kvm, gfn_t gfn)
+static void unaccount_shadowed(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
+       struct kvm_memslots *slots;
        struct kvm_memory_slot *slot;
        struct kvm_lpage_info *linfo;
+       gfn_t gfn;
        int i;
 
-       slot = gfn_to_memslot(kvm, gfn);
-       for (i = PT_DIRECTORY_LEVEL;
-            i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
+       gfn = sp->gfn;
+       slots = kvm_memslots_for_spte_role(kvm, sp->role);
+       slot = __gfn_to_memslot(slots, gfn);
+       for (i = PT_DIRECTORY_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
                linfo = lpage_info_slot(gfn, slot, i);
                linfo->write_count -= 1;
                WARN_ON(linfo->write_count < 0);
@@ -835,14 +841,14 @@ static void unaccount_shadowed(struct kvm *kvm, gfn_t gfn)
        kvm->arch.indirect_shadow_pages--;
 }
 
-static int has_wrprotected_page(struct kvm *kvm,
+static int has_wrprotected_page(struct kvm_vcpu *vcpu,
                                gfn_t gfn,
                                int level)
 {
        struct kvm_memory_slot *slot;
        struct kvm_lpage_info *linfo;
 
-       slot = gfn_to_memslot(kvm, gfn);
+       slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
        if (slot) {
                linfo = lpage_info_slot(gfn, slot, level);
                return linfo->write_count;
@@ -858,8 +864,7 @@ static int host_mapping_level(struct kvm *kvm, gfn_t gfn)
 
        page_size = kvm_host_page_size(kvm, gfn);
 
-       for (i = PT_PAGE_TABLE_LEVEL;
-            i < (PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES); ++i) {
+       for (i = PT_PAGE_TABLE_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
                if (page_size >= KVM_HPAGE_SIZE(i))
                        ret = i;
                else
@@ -875,7 +880,7 @@ gfn_to_memslot_dirty_bitmap(struct kvm_vcpu *vcpu, gfn_t gfn,
 {
        struct kvm_memory_slot *slot;
 
-       slot = gfn_to_memslot(vcpu->kvm, gfn);
+       slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
        if (!slot || slot->flags & KVM_MEMSLOT_INVALID ||
              (no_dirty_log && slot->dirty_bitmap))
                slot = NULL;
@@ -900,7 +905,7 @@ static int mapping_level(struct kvm_vcpu *vcpu, gfn_t large_gfn)
        max_level = min(kvm_x86_ops->get_lpage_level(), host_level);
 
        for (level = PT_DIRECTORY_LEVEL; level <= max_level; ++level)
-               if (has_wrprotected_page(vcpu->kvm, large_gfn, level))
+               if (has_wrprotected_page(vcpu, large_gfn, level))
                        break;
 
        return level - 1;
@@ -1042,12 +1047,14 @@ static unsigned long *__gfn_to_rmap(gfn_t gfn, int level,
 /*
  * Take gfn and return the reverse mapping to it.
  */
-static unsigned long *gfn_to_rmap(struct kvm *kvm, gfn_t gfn, int level)
+static unsigned long *gfn_to_rmap(struct kvm *kvm, gfn_t gfn, struct kvm_mmu_page *sp)
 {
+       struct kvm_memslots *slots;
        struct kvm_memory_slot *slot;
 
-       slot = gfn_to_memslot(kvm, gfn);
-       return __gfn_to_rmap(gfn, level, slot);
+       slots = kvm_memslots_for_spte_role(kvm, sp->role);
+       slot = __gfn_to_memslot(slots, gfn);
+       return __gfn_to_rmap(gfn, sp->role.level, slot);
 }
 
 static bool rmap_can_add(struct kvm_vcpu *vcpu)
@@ -1065,7 +1072,7 @@ static int rmap_add(struct kvm_vcpu *vcpu, u64 *spte, gfn_t gfn)
 
        sp = page_header(__pa(spte));
        kvm_mmu_page_set_gfn(sp, spte - sp->spt, gfn);
-       rmapp = gfn_to_rmap(vcpu->kvm, gfn, sp->role.level);
+       rmapp = gfn_to_rmap(vcpu->kvm, gfn, sp);
        return pte_list_add(vcpu, spte, rmapp);
 }
 
@@ -1077,7 +1084,7 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
 
        sp = page_header(__pa(spte));
        gfn = kvm_mmu_page_get_gfn(sp, spte - sp->spt);
-       rmapp = gfn_to_rmap(kvm, gfn, sp->role.level);
+       rmapp = gfn_to_rmap(kvm, gfn, sp);
        pte_list_remove(spte, rmapp);
 }
 
@@ -1142,6 +1149,11 @@ static u64 *rmap_get_next(struct rmap_iterator *iter)
        return NULL;
 }
 
+#define for_each_rmap_spte(_rmap_, _iter_, _spte_)                         \
+          for (_spte_ = rmap_get_first(*_rmap_, _iter_);                   \
+               _spte_ && ({BUG_ON(!is_shadow_present_pte(*_spte_)); 1;});  \
+                       _spte_ = rmap_get_next(_iter_))
+
 static void drop_spte(struct kvm *kvm, u64 *sptep)
 {
        if (mmu_spte_clear_track_bits(sptep))
@@ -1205,12 +1217,8 @@ static bool __rmap_write_protect(struct kvm *kvm, unsigned long *rmapp,
        struct rmap_iterator iter;
        bool flush = false;
 
-       for (sptep = rmap_get_first(*rmapp, &iter); sptep;) {
-               BUG_ON(!(*sptep & PT_PRESENT_MASK));
-
+       for_each_rmap_spte(rmapp, &iter, sptep)
                flush |= spte_write_protect(kvm, sptep, pt_protect);
-               sptep = rmap_get_next(&iter);
-       }
 
        return flush;
 }
@@ -1232,12 +1240,8 @@ static bool __rmap_clear_dirty(struct kvm *kvm, unsigned long *rmapp)
        struct rmap_iterator iter;
        bool flush = false;
 
-       for (sptep = rmap_get_first(*rmapp, &iter); sptep;) {
-               BUG_ON(!(*sptep & PT_PRESENT_MASK));
-
+       for_each_rmap_spte(rmapp, &iter, sptep)
                flush |= spte_clear_dirty(kvm, sptep);
-               sptep = rmap_get_next(&iter);
-       }
 
        return flush;
 }
@@ -1259,12 +1263,8 @@ static bool __rmap_set_dirty(struct kvm *kvm, unsigned long *rmapp)
        struct rmap_iterator iter;
        bool flush = false;
 
-       for (sptep = rmap_get_first(*rmapp, &iter); sptep;) {
-               BUG_ON(!(*sptep & PT_PRESENT_MASK));
-
+       for_each_rmap_spte(rmapp, &iter, sptep)
                flush |= spte_set_dirty(kvm, sptep);
-               sptep = rmap_get_next(&iter);
-       }
 
        return flush;
 }
@@ -1342,42 +1342,45 @@ void kvm_arch_mmu_enable_log_dirty_pt_masked(struct kvm *kvm,
                kvm_mmu_write_protect_pt_masked(kvm, slot, gfn_offset, mask);
 }
 
-static bool rmap_write_protect(struct kvm *kvm, u64 gfn)
+static bool rmap_write_protect(struct kvm_vcpu *vcpu, u64 gfn)
 {
        struct kvm_memory_slot *slot;
        unsigned long *rmapp;
        int i;
        bool write_protected = false;
 
-       slot = gfn_to_memslot(kvm, gfn);
+       slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
 
-       for (i = PT_PAGE_TABLE_LEVEL;
-            i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
+       for (i = PT_PAGE_TABLE_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
                rmapp = __gfn_to_rmap(gfn, i, slot);
-               write_protected |= __rmap_write_protect(kvm, rmapp, true);
+               write_protected |= __rmap_write_protect(vcpu->kvm, rmapp, true);
        }
 
        return write_protected;
 }
 
-static int kvm_unmap_rmapp(struct kvm *kvm, unsigned long *rmapp,
-                          struct kvm_memory_slot *slot, gfn_t gfn, int level,
-                          unsigned long data)
+static bool kvm_zap_rmapp(struct kvm *kvm, unsigned long *rmapp)
 {
        u64 *sptep;
        struct rmap_iterator iter;
-       int need_tlb_flush = 0;
+       bool flush = false;
 
        while ((sptep = rmap_get_first(*rmapp, &iter))) {
                BUG_ON(!(*sptep & PT_PRESENT_MASK));
-               rmap_printk("kvm_rmap_unmap_hva: spte %p %llx gfn %llx (%d)\n",
-                            sptep, *sptep, gfn, level);
+               rmap_printk("%s: spte %p %llx.\n", __func__, sptep, *sptep);
 
                drop_spte(kvm, sptep);
-               need_tlb_flush = 1;
+               flush = true;
        }
 
-       return need_tlb_flush;
+       return flush;
+}
+
+static int kvm_unmap_rmapp(struct kvm *kvm, unsigned long *rmapp,
+                          struct kvm_memory_slot *slot, gfn_t gfn, int level,
+                          unsigned long data)
+{
+       return kvm_zap_rmapp(kvm, rmapp);
 }
 
 static int kvm_set_pte_rmapp(struct kvm *kvm, unsigned long *rmapp,
@@ -1394,8 +1397,8 @@ static int kvm_set_pte_rmapp(struct kvm *kvm, unsigned long *rmapp,
        WARN_ON(pte_huge(*ptep));
        new_pfn = pte_pfn(*ptep);
 
-       for (sptep = rmap_get_first(*rmapp, &iter); sptep;) {
-               BUG_ON(!is_shadow_present_pte(*sptep));
+restart:
+       for_each_rmap_spte(rmapp, &iter, sptep) {
                rmap_printk("kvm_set_pte_rmapp: spte %p %llx gfn %llx (%d)\n",
                             sptep, *sptep, gfn, level);
 
@@ -1403,7 +1406,7 @@ static int kvm_set_pte_rmapp(struct kvm *kvm, unsigned long *rmapp,
 
                if (pte_write(*ptep)) {
                        drop_spte(kvm, sptep);
-                       sptep = rmap_get_first(*rmapp, &iter);
+                       goto restart;
                } else {
                        new_spte = *sptep & ~PT64_BASE_ADDR_MASK;
                        new_spte |= (u64)new_pfn << PAGE_SHIFT;
@@ -1414,7 +1417,6 @@ static int kvm_set_pte_rmapp(struct kvm *kvm, unsigned long *rmapp,
 
                        mmu_spte_clear_track_bits(sptep);
                        mmu_spte_set(sptep, new_spte);
-                       sptep = rmap_get_next(&iter);
                }
        }
 
@@ -1424,6 +1426,74 @@ static int kvm_set_pte_rmapp(struct kvm *kvm, unsigned long *rmapp,
        return 0;
 }
 
+struct slot_rmap_walk_iterator {
+       /* input fields. */
+       struct kvm_memory_slot *slot;
+       gfn_t start_gfn;
+       gfn_t end_gfn;
+       int start_level;
+       int end_level;
+
+       /* output fields. */
+       gfn_t gfn;
+       unsigned long *rmap;
+       int level;
+
+       /* private field. */
+       unsigned long *end_rmap;
+};
+
+static void
+rmap_walk_init_level(struct slot_rmap_walk_iterator *iterator, int level)
+{
+       iterator->level = level;
+       iterator->gfn = iterator->start_gfn;
+       iterator->rmap = __gfn_to_rmap(iterator->gfn, level, iterator->slot);
+       iterator->end_rmap = __gfn_to_rmap(iterator->end_gfn, level,
+                                          iterator->slot);
+}
+
+static void
+slot_rmap_walk_init(struct slot_rmap_walk_iterator *iterator,
+                   struct kvm_memory_slot *slot, int start_level,
+                   int end_level, gfn_t start_gfn, gfn_t end_gfn)
+{
+       iterator->slot = slot;
+       iterator->start_level = start_level;
+       iterator->end_level = end_level;
+       iterator->start_gfn = start_gfn;
+       iterator->end_gfn = end_gfn;
+
+       rmap_walk_init_level(iterator, iterator->start_level);
+}
+
+static bool slot_rmap_walk_okay(struct slot_rmap_walk_iterator *iterator)
+{
+       return !!iterator->rmap;
+}
+
+static void slot_rmap_walk_next(struct slot_rmap_walk_iterator *iterator)
+{
+       if (++iterator->rmap <= iterator->end_rmap) {
+               iterator->gfn += (1UL << KVM_HPAGE_GFN_SHIFT(iterator->level));
+               return;
+       }
+
+       if (++iterator->level > iterator->end_level) {
+               iterator->rmap = NULL;
+               return;
+       }
+
+       rmap_walk_init_level(iterator, iterator->level);
+}
+
+#define for_each_slot_rmap_range(_slot_, _start_level_, _end_level_,   \
+          _start_gfn, _end_gfn, _iter_)                                \
+       for (slot_rmap_walk_init(_iter_, _slot_, _start_level_,         \
+                                _end_level_, _start_gfn, _end_gfn);    \
+            slot_rmap_walk_okay(_iter_);                               \
+            slot_rmap_walk_next(_iter_))
+
 static int kvm_handle_hva_range(struct kvm *kvm,
                                unsigned long start,
                                unsigned long end,
@@ -1435,48 +1505,36 @@ static int kvm_handle_hva_range(struct kvm *kvm,
                                               int level,
                                               unsigned long data))
 {
-       int j;
-       int ret = 0;
        struct kvm_memslots *slots;
        struct kvm_memory_slot *memslot;
+       struct slot_rmap_walk_iterator iterator;
+       int ret = 0;
+       int i;
 
-       slots = kvm_memslots(kvm);
-
-       kvm_for_each_memslot(memslot, slots) {
-               unsigned long hva_start, hva_end;
-               gfn_t gfn_start, gfn_end;
-
-               hva_start = max(start, memslot->userspace_addr);
-               hva_end = min(end, memslot->userspace_addr +
-                                       (memslot->npages << PAGE_SHIFT));
-               if (hva_start >= hva_end)
-                       continue;
-               /*
-                * {gfn(page) | page intersects with [hva_start, hva_end)} =
-                * {gfn_start, gfn_start+1, ..., gfn_end-1}.
-                */
-               gfn_start = hva_to_gfn_memslot(hva_start, memslot);
-               gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
-
-               for (j = PT_PAGE_TABLE_LEVEL;
-                    j < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++j) {
-                       unsigned long idx, idx_end;
-                       unsigned long *rmapp;
-                       gfn_t gfn = gfn_start;
+       for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
+               slots = __kvm_memslots(kvm, i);
+               kvm_for_each_memslot(memslot, slots) {
+                       unsigned long hva_start, hva_end;
+                       gfn_t gfn_start, gfn_end;
 
+                       hva_start = max(start, memslot->userspace_addr);
+                       hva_end = min(end, memslot->userspace_addr +
+                                     (memslot->npages << PAGE_SHIFT));
+                       if (hva_start >= hva_end)
+                               continue;
                        /*
-                        * {idx(page_j) | page_j intersects with
-                        *  [hva_start, hva_end)} = {idx, idx+1, ..., idx_end}.
+                        * {gfn(page) | page intersects with [hva_start, hva_end)} =
+                        * {gfn_start, gfn_start+1, ..., gfn_end-1}.
                         */
-                       idx = gfn_to_index(gfn_start, memslot->base_gfn, j);
-                       idx_end = gfn_to_index(gfn_end - 1, memslot->base_gfn, j);
-
-                       rmapp = __gfn_to_rmap(gfn_start, j, memslot);
-
-                       for (; idx <= idx_end;
-                              ++idx, gfn += (1UL << KVM_HPAGE_GFN_SHIFT(j)))
-                               ret |= handler(kvm, rmapp++, memslot,
-                                              gfn, j, data);
+                       gfn_start = hva_to_gfn_memslot(hva_start, memslot);
+                       gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
+
+                       for_each_slot_rmap_range(memslot, PT_PAGE_TABLE_LEVEL,
+                                                PT_MAX_HUGEPAGE_LEVEL,
+                                                gfn_start, gfn_end - 1,
+                                                &iterator)
+                               ret |= handler(kvm, iterator.rmap, memslot,
+                                              iterator.gfn, iterator.level, data);
                }
        }
 
@@ -1518,16 +1576,13 @@ static int kvm_age_rmapp(struct kvm *kvm, unsigned long *rmapp,
 
        BUG_ON(!shadow_accessed_mask);
 
-       for (sptep = rmap_get_first(*rmapp, &iter); sptep;
-            sptep = rmap_get_next(&iter)) {
-               BUG_ON(!is_shadow_present_pte(*sptep));
-
+       for_each_rmap_spte(rmapp, &iter, sptep)
                if (*sptep & shadow_accessed_mask) {
                        young = 1;
                        clear_bit((ffs(shadow_accessed_mask) - 1),
                                 (unsigned long *)sptep);
                }
-       }
+
        trace_kvm_age_page(gfn, level, slot, young);
        return young;
 }
@@ -1548,15 +1603,11 @@ static int kvm_test_age_rmapp(struct kvm *kvm, unsigned long *rmapp,
        if (!shadow_accessed_mask)
                goto out;
 
-       for (sptep = rmap_get_first(*rmapp, &iter); sptep;
-            sptep = rmap_get_next(&iter)) {
-               BUG_ON(!is_shadow_present_pte(*sptep));
-
+       for_each_rmap_spte(rmapp, &iter, sptep)
                if (*sptep & shadow_accessed_mask) {
                        young = 1;
                        break;
                }
-       }
 out:
        return young;
 }
@@ -1570,7 +1621,7 @@ static void rmap_recycle(struct kvm_vcpu *vcpu, u64 *spte, gfn_t gfn)
 
        sp = page_header(__pa(spte));
 
-       rmapp = gfn_to_rmap(vcpu->kvm, gfn, sp->role.level);
+       rmapp = gfn_to_rmap(vcpu->kvm, gfn, sp);
 
        kvm_unmap_rmapp(vcpu->kvm, rmapp, NULL, gfn, sp->role.level, 0);
        kvm_flush_remote_tlbs(vcpu->kvm);
@@ -1990,7 +2041,7 @@ static void mmu_sync_children(struct kvm_vcpu *vcpu,
                bool protected = false;
 
                for_each_sp(pages, sp, parents, i)
-                       protected |= rmap_write_protect(vcpu->kvm, sp->gfn);
+                       protected |= rmap_write_protect(vcpu, sp->gfn);
 
                if (protected)
                        kvm_flush_remote_tlbs(vcpu->kvm);
@@ -2088,12 +2139,12 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        hlist_add_head(&sp->hash_link,
                &vcpu->kvm->arch.mmu_page_hash[kvm_page_table_hashfn(gfn)]);
        if (!direct) {
-               if (rmap_write_protect(vcpu->kvm, gfn))
+               if (rmap_write_protect(vcpu, gfn))
                        kvm_flush_remote_tlbs(vcpu->kvm);
                if (level > PT_PAGE_TABLE_LEVEL && need_sync)
                        kvm_sync_pages(vcpu, gfn);
 
-               account_shadowed(vcpu->kvm, gfn);
+               account_shadowed(vcpu->kvm, sp);
        }
        sp->mmu_valid_gen = vcpu->kvm->arch.mmu_valid_gen;
        init_shadow_page_table(sp);
@@ -2274,7 +2325,7 @@ static int kvm_mmu_prepare_zap_page(struct kvm *kvm, struct kvm_mmu_page *sp,
        kvm_mmu_unlink_parents(kvm, sp);
 
        if (!sp->role.invalid && !sp->role.direct)
-               unaccount_shadowed(kvm, sp->gfn);
+               unaccount_shadowed(kvm, sp);
 
        if (sp->unsync)
                kvm_unlink_unsync_page(kvm, sp);
@@ -2386,111 +2437,6 @@ int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn)
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_unprotect_page);
 
-/*
- * The function is based on mtrr_type_lookup() in
- * arch/x86/kernel/cpu/mtrr/generic.c
- */
-static int get_mtrr_type(struct mtrr_state_type *mtrr_state,
-                        u64 start, u64 end)
-{
-       int i;
-       u64 base, mask;
-       u8 prev_match, curr_match;
-       int num_var_ranges = KVM_NR_VAR_MTRR;
-
-       if (!mtrr_state->enabled)
-               return 0xFF;
-
-       /* Make end inclusive end, instead of exclusive */
-       end--;
-
-       /* Look in fixed ranges. Just return the type as per start */
-       if (mtrr_state->have_fixed && (start < 0x100000)) {
-               int idx;
-
-               if (start < 0x80000) {
-                       idx = 0;
-                       idx += (start >> 16);
-                       return mtrr_state->fixed_ranges[idx];
-               } else if (start < 0xC0000) {
-                       idx = 1 * 8;
-                       idx += ((start - 0x80000) >> 14);
-                       return mtrr_state->fixed_ranges[idx];
-               } else if (start < 0x1000000) {
-                       idx = 3 * 8;
-                       idx += ((start - 0xC0000) >> 12);
-                       return mtrr_state->fixed_ranges[idx];
-               }
-       }
-
-       /*
-        * Look in variable ranges
-        * Look of multiple ranges matching this address and pick type
-        * as per MTRR precedence
-        */
-       if (!(mtrr_state->enabled & 2))
-               return mtrr_state->def_type;
-
-       prev_match = 0xFF;
-       for (i = 0; i < num_var_ranges; ++i) {
-               unsigned short start_state, end_state;
-
-               if (!(mtrr_state->var_ranges[i].mask_lo & (1 << 11)))
-                       continue;
-
-               base = (((u64)mtrr_state->var_ranges[i].base_hi) << 32) +
-                      (mtrr_state->var_ranges[i].base_lo & PAGE_MASK);
-               mask = (((u64)mtrr_state->var_ranges[i].mask_hi) << 32) +
-                      (mtrr_state->var_ranges[i].mask_lo & PAGE_MASK);
-
-               start_state = ((start & mask) == (base & mask));
-               end_state = ((end & mask) == (base & mask));
-               if (start_state != end_state)
-                       return 0xFE;
-
-               if ((start & mask) != (base & mask))
-                       continue;
-
-               curr_match = mtrr_state->var_ranges[i].base_lo & 0xff;
-               if (prev_match == 0xFF) {
-                       prev_match = curr_match;
-                       continue;
-               }
-
-               if (prev_match == MTRR_TYPE_UNCACHABLE ||
-                   curr_match == MTRR_TYPE_UNCACHABLE)
-                       return MTRR_TYPE_UNCACHABLE;
-
-               if ((prev_match == MTRR_TYPE_WRBACK &&
-                    curr_match == MTRR_TYPE_WRTHROUGH) ||
-                   (prev_match == MTRR_TYPE_WRTHROUGH &&
-                    curr_match == MTRR_TYPE_WRBACK)) {
-                       prev_match = MTRR_TYPE_WRTHROUGH;
-                       curr_match = MTRR_TYPE_WRTHROUGH;
-               }
-
-               if (prev_match != curr_match)
-                       return MTRR_TYPE_UNCACHABLE;
-       }
-
-       if (prev_match != 0xFF)
-               return prev_match;
-
-       return mtrr_state->def_type;
-}
-
-u8 kvm_get_guest_memory_type(struct kvm_vcpu *vcpu, gfn_t gfn)
-{
-       u8 mtrr;
-
-       mtrr = get_mtrr_type(&vcpu->arch.mtrr_state, gfn << PAGE_SHIFT,
-                            (gfn << PAGE_SHIFT) + PAGE_SIZE);
-       if (mtrr == 0xfe || mtrr == 0xff)
-               mtrr = MTRR_TYPE_WRBACK;
-       return mtrr;
-}
-EXPORT_SYMBOL_GPL(kvm_get_guest_memory_type);
-
 static void __kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
 {
        trace_kvm_mmu_unsync_page(sp);
@@ -2541,7 +2487,7 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
        u64 spte;
        int ret = 0;
 
-       if (set_mmio_spte(vcpu->kvm, sptep, gfn, pfn, pte_access))
+       if (set_mmio_spte(vcpu, sptep, gfn, pfn, pte_access))
                return 0;
 
        spte = PT_PRESENT_MASK;
@@ -2578,7 +2524,7 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                 * be fixed if guest refault.
                 */
                if (level > PT_PAGE_TABLE_LEVEL &&
-                   has_wrprotected_page(vcpu->kvm, gfn, level))
+                   has_wrprotected_page(vcpu, gfn, level))
                        goto done;
 
                spte |= PT_WRITABLE_MASK | SPTE_MMU_WRITEABLE;
@@ -2602,7 +2548,7 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
        }
 
        if (pte_access & ACC_WRITE_MASK) {
-               mark_page_dirty(vcpu->kvm, gfn);
+               kvm_vcpu_mark_page_dirty(vcpu, gfn);
                spte |= shadow_dirty_mask;
        }
 
@@ -2692,15 +2638,17 @@ static int direct_pte_prefetch_many(struct kvm_vcpu *vcpu,
                                    u64 *start, u64 *end)
 {
        struct page *pages[PTE_PREFETCH_NUM];
+       struct kvm_memory_slot *slot;
        unsigned access = sp->role.access;
        int i, ret;
        gfn_t gfn;
 
        gfn = kvm_mmu_page_get_gfn(sp, start - sp->spt);
-       if (!gfn_to_memslot_dirty_bitmap(vcpu, gfn, access & ACC_WRITE_MASK))
+       slot = gfn_to_memslot_dirty_bitmap(vcpu, gfn, access & ACC_WRITE_MASK);
+       if (!slot)
                return -1;
 
-       ret = gfn_to_page_many_atomic(vcpu->kvm, gfn, pages, end - start);
+       ret = gfn_to_page_many_atomic(slot, gfn, pages, end - start);
        if (ret <= 0)
                return -1;
 
@@ -2818,7 +2766,7 @@ static int kvm_handle_bad_page(struct kvm_vcpu *vcpu, gfn_t gfn, pfn_t pfn)
                return 1;
 
        if (pfn == KVM_PFN_ERR_HWPOISON) {
-               kvm_send_hwpoison_signal(gfn_to_hva(vcpu->kvm, gfn), current);
+               kvm_send_hwpoison_signal(kvm_vcpu_gfn_to_hva(vcpu, gfn), current);
                return 0;
        }
 
@@ -2841,7 +2789,7 @@ static void transparent_hugepage_adjust(struct kvm_vcpu *vcpu,
        if (!is_error_noslot_pfn(pfn) && !kvm_is_reserved_pfn(pfn) &&
            level == PT_PAGE_TABLE_LEVEL &&
            PageTransCompound(pfn_to_page(pfn)) &&
-           !has_wrprotected_page(vcpu->kvm, gfn, PT_DIRECTORY_LEVEL)) {
+           !has_wrprotected_page(vcpu, gfn, PT_DIRECTORY_LEVEL)) {
                unsigned long mask;
                /*
                 * mmu_notifier_retry was successful and we hold the
@@ -2933,7 +2881,7 @@ fast_pf_fix_direct_spte(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
         * Compare with set_spte where instead shadow_dirty_mask is set.
         */
        if (cmpxchg64(sptep, spte, spte | PT_WRITABLE_MASK) == spte)
-               mark_page_dirty(vcpu->kvm, gfn);
+               kvm_vcpu_mark_page_dirty(vcpu, gfn);
 
        return true;
 }
@@ -3388,7 +3336,7 @@ int handle_mmio_page_fault_common(struct kvm_vcpu *vcpu, u64 addr, bool direct)
                gfn_t gfn = get_mmio_spte_gfn(spte);
                unsigned access = get_mmio_spte_access(spte);
 
-               if (!check_mmio_spte(vcpu->kvm, spte))
+               if (!check_mmio_spte(vcpu, spte))
                        return RET_MMIO_PF_INVALID;
 
                if (direct)
@@ -3460,7 +3408,7 @@ static int kvm_arch_setup_async_pf(struct kvm_vcpu *vcpu, gva_t gva, gfn_t gfn)
        arch.direct_map = vcpu->arch.mmu.direct_map;
        arch.cr3 = vcpu->arch.mmu.get_cr3(vcpu);
 
-       return kvm_setup_async_pf(vcpu, gva, gfn_to_hva(vcpu->kvm, gfn), &arch);
+       return kvm_setup_async_pf(vcpu, gva, kvm_vcpu_gfn_to_hva(vcpu, gfn), &arch);
 }
 
 static bool can_do_async_pf(struct kvm_vcpu *vcpu)
@@ -3475,10 +3423,12 @@ static bool can_do_async_pf(struct kvm_vcpu *vcpu)
 static bool try_async_pf(struct kvm_vcpu *vcpu, bool prefault, gfn_t gfn,
                         gva_t gva, pfn_t *pfn, bool write, bool *writable)
 {
+       struct kvm_memory_slot *slot;
        bool async;
 
-       *pfn = gfn_to_pfn_async(vcpu->kvm, gfn, &async, write, writable);
-
+       slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
+       async = false;
+       *pfn = __gfn_to_pfn_memslot(slot, gfn, false, &async, write, writable);
        if (!async)
                return false; /* *pfn has correct page already */
 
@@ -3492,11 +3442,20 @@ static bool try_async_pf(struct kvm_vcpu *vcpu, bool prefault, gfn_t gfn,
                        return true;
        }
 
-       *pfn = gfn_to_pfn_prot(vcpu->kvm, gfn, write, writable);
-
+       *pfn = __gfn_to_pfn_memslot(slot, gfn, false, NULL, write, writable);
        return false;
 }
 
+static bool
+check_hugepage_cache_consistency(struct kvm_vcpu *vcpu, gfn_t gfn, int level)
+{
+       int page_num = KVM_PAGES_PER_HPAGE(level);
+
+       gfn &= ~(page_num - 1);
+
+       return kvm_mtrr_check_gfn_range_consistency(vcpu, gfn, page_num);
+}
+
 static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa, u32 error_code,
                          bool prefault)
 {
@@ -3522,9 +3481,17 @@ static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa, u32 error_code,
        if (r)
                return r;
 
-       force_pt_level = mapping_level_dirty_bitmap(vcpu, gfn);
+       if (mapping_level_dirty_bitmap(vcpu, gfn) ||
+           !check_hugepage_cache_consistency(vcpu, gfn, PT_DIRECTORY_LEVEL))
+               force_pt_level = 1;
+       else
+               force_pt_level = 0;
+
        if (likely(!force_pt_level)) {
                level = mapping_level(vcpu, gfn);
+               if (level > PT_DIRECTORY_LEVEL &&
+                   !check_hugepage_cache_consistency(vcpu, gfn, level))
+                       level = PT_DIRECTORY_LEVEL;
                gfn &= ~(KVM_PAGES_PER_HPAGE(level) - 1);
        } else
                level = PT_PAGE_TABLE_LEVEL;
@@ -3590,7 +3557,7 @@ static void inject_page_fault(struct kvm_vcpu *vcpu,
        vcpu->arch.mmu.inject_page_fault(vcpu, fault);
 }
 
-static bool sync_mmio_spte(struct kvm *kvm, u64 *sptep, gfn_t gfn,
+static bool sync_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, gfn_t gfn,
                           unsigned access, int *nr_present)
 {
        if (unlikely(is_mmio_spte(*sptep))) {
@@ -3600,7 +3567,7 @@ static bool sync_mmio_spte(struct kvm *kvm, u64 *sptep, gfn_t gfn,
                }
 
                (*nr_present)++;
-               mark_mmio_spte(kvm, sptep, gfn, access);
+               mark_mmio_spte(vcpu, sptep, gfn, access);
                return true;
        }
 
@@ -3878,6 +3845,7 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
        struct kvm_mmu *context = &vcpu->arch.mmu;
 
        context->base_role.word = 0;
+       context->base_role.smm = is_smm(vcpu);
        context->page_fault = tdp_page_fault;
        context->sync_page = nonpaging_sync_page;
        context->invlpg = nonpaging_invlpg;
@@ -3939,6 +3907,7 @@ void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu)
                = smep && !is_write_protection(vcpu);
        context->base_role.smap_andnot_wp
                = smap && !is_write_protection(vcpu);
+       context->base_role.smm = is_smm(vcpu);
 }
 EXPORT_SYMBOL_GPL(kvm_init_shadow_mmu);
 
@@ -4110,7 +4079,7 @@ static u64 mmu_pte_write_fetch_gpte(struct kvm_vcpu *vcpu, gpa_t *gpa,
                /* Handle a 32-bit guest writing two halves of a 64-bit gpte */
                *gpa &= ~(gpa_t)7;
                *bytes = 8;
-               r = kvm_read_guest(vcpu->kvm, *gpa, &gentry, 8);
+               r = kvm_vcpu_read_guest(vcpu, *gpa, &gentry, 8);
                if (r)
                        gentry = 0;
                new = (const u8 *)&gentry;
@@ -4215,13 +4184,14 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        u64 entry, gentry, *spte;
        int npte;
        bool remote_flush, local_flush, zap_page;
-       union kvm_mmu_page_role mask = (union kvm_mmu_page_role) {
-               .cr0_wp = 1,
-               .cr4_pae = 1,
-               .nxe = 1,
-               .smep_andnot_wp = 1,
-               .smap_andnot_wp = 1,
-       };
+       union kvm_mmu_page_role mask = { };
+
+       mask.cr0_wp = 1;
+       mask.cr4_pae = 1;
+       mask.nxe = 1;
+       mask.smep_andnot_wp = 1;
+       mask.smap_andnot_wp = 1;
+       mask.smm = 1;
 
        /*
         * If we don't have indirect shadow pages, it means no page is
@@ -4420,36 +4390,115 @@ void kvm_mmu_setup(struct kvm_vcpu *vcpu)
        init_kvm_mmu(vcpu);
 }
 
-void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
-                                     struct kvm_memory_slot *memslot)
+/* The return value indicates if tlb flush on all vcpus is needed. */
+typedef bool (*slot_level_handler) (struct kvm *kvm, unsigned long *rmap);
+
+/* The caller should hold mmu-lock before calling this function. */
+static bool
+slot_handle_level_range(struct kvm *kvm, struct kvm_memory_slot *memslot,
+                       slot_level_handler fn, int start_level, int end_level,
+                       gfn_t start_gfn, gfn_t end_gfn, bool lock_flush_tlb)
 {
-       gfn_t last_gfn;
-       int i;
+       struct slot_rmap_walk_iterator iterator;
        bool flush = false;
 
-       last_gfn = memslot->base_gfn + memslot->npages - 1;
+       for_each_slot_rmap_range(memslot, start_level, end_level, start_gfn,
+                       end_gfn, &iterator) {
+               if (iterator.rmap)
+                       flush |= fn(kvm, iterator.rmap);
 
-       spin_lock(&kvm->mmu_lock);
+               if (need_resched() || spin_needbreak(&kvm->mmu_lock)) {
+                       if (flush && lock_flush_tlb) {
+                               kvm_flush_remote_tlbs(kvm);
+                               flush = false;
+                       }
+                       cond_resched_lock(&kvm->mmu_lock);
+               }
+       }
+
+       if (flush && lock_flush_tlb) {
+               kvm_flush_remote_tlbs(kvm);
+               flush = false;
+       }
 
-       for (i = PT_PAGE_TABLE_LEVEL;
-            i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
-               unsigned long *rmapp;
-               unsigned long last_index, index;
+       return flush;
+}
+
+static bool
+slot_handle_level(struct kvm *kvm, struct kvm_memory_slot *memslot,
+                 slot_level_handler fn, int start_level, int end_level,
+                 bool lock_flush_tlb)
+{
+       return slot_handle_level_range(kvm, memslot, fn, start_level,
+                       end_level, memslot->base_gfn,
+                       memslot->base_gfn + memslot->npages - 1,
+                       lock_flush_tlb);
+}
+
+static bool
+slot_handle_all_level(struct kvm *kvm, struct kvm_memory_slot *memslot,
+                     slot_level_handler fn, bool lock_flush_tlb)
+{
+       return slot_handle_level(kvm, memslot, fn, PT_PAGE_TABLE_LEVEL,
+                                PT_MAX_HUGEPAGE_LEVEL, lock_flush_tlb);
+}
+
+static bool
+slot_handle_large_level(struct kvm *kvm, struct kvm_memory_slot *memslot,
+                       slot_level_handler fn, bool lock_flush_tlb)
+{
+       return slot_handle_level(kvm, memslot, fn, PT_PAGE_TABLE_LEVEL + 1,
+                                PT_MAX_HUGEPAGE_LEVEL, lock_flush_tlb);
+}
+
+static bool
+slot_handle_leaf(struct kvm *kvm, struct kvm_memory_slot *memslot,
+                slot_level_handler fn, bool lock_flush_tlb)
+{
+       return slot_handle_level(kvm, memslot, fn, PT_PAGE_TABLE_LEVEL,
+                                PT_PAGE_TABLE_LEVEL, lock_flush_tlb);
+}
 
-               rmapp = memslot->arch.rmap[i - PT_PAGE_TABLE_LEVEL];
-               last_index = gfn_to_index(last_gfn, memslot->base_gfn, i);
+void kvm_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_end)
+{
+       struct kvm_memslots *slots;
+       struct kvm_memory_slot *memslot;
+       int i;
 
-               for (index = 0; index <= last_index; ++index, ++rmapp) {
-                       if (*rmapp)
-                               flush |= __rmap_write_protect(kvm, rmapp,
-                                               false);
+       spin_lock(&kvm->mmu_lock);
+       for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
+               slots = __kvm_memslots(kvm, i);
+               kvm_for_each_memslot(memslot, slots) {
+                       gfn_t start, end;
+
+                       start = max(gfn_start, memslot->base_gfn);
+                       end = min(gfn_end, memslot->base_gfn + memslot->npages);
+                       if (start >= end)
+                               continue;
 
-                       if (need_resched() || spin_needbreak(&kvm->mmu_lock))
-                               cond_resched_lock(&kvm->mmu_lock);
+                       slot_handle_level_range(kvm, memslot, kvm_zap_rmapp,
+                                               PT_PAGE_TABLE_LEVEL, PT_MAX_HUGEPAGE_LEVEL,
+                                               start, end - 1, true);
                }
        }
 
        spin_unlock(&kvm->mmu_lock);
+}
+
+static bool slot_rmap_write_protect(struct kvm *kvm, unsigned long *rmapp)
+{
+       return __rmap_write_protect(kvm, rmapp, false);
+}
+
+void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
+                                     struct kvm_memory_slot *memslot)
+{
+       bool flush;
+
+       spin_lock(&kvm->mmu_lock);
+       flush = slot_handle_all_level(kvm, memslot, slot_rmap_write_protect,
+                                     false);
+       spin_unlock(&kvm->mmu_lock);
 
        /*
         * kvm_mmu_slot_remove_write_access() and kvm_vm_ioctl_get_dirty_log()
@@ -4482,9 +4531,8 @@ static bool kvm_mmu_zap_collapsible_spte(struct kvm *kvm,
        pfn_t pfn;
        struct kvm_mmu_page *sp;
 
-       for (sptep = rmap_get_first(*rmapp, &iter); sptep;) {
-               BUG_ON(!(*sptep & PT_PRESENT_MASK));
-
+restart:
+       for_each_rmap_spte(rmapp, &iter, sptep) {
                sp = page_header(__pa(sptep));
                pfn = spte_to_pfn(*sptep);
 
@@ -4499,71 +4547,31 @@ static bool kvm_mmu_zap_collapsible_spte(struct kvm *kvm,
                        !kvm_is_reserved_pfn(pfn) &&
                        PageTransCompound(pfn_to_page(pfn))) {
                        drop_spte(kvm, sptep);
-                       sptep = rmap_get_first(*rmapp, &iter);
                        need_tlb_flush = 1;
-               } else
-                       sptep = rmap_get_next(&iter);
+                       goto restart;
+               }
        }
 
        return need_tlb_flush;
 }
 
 void kvm_mmu_zap_collapsible_sptes(struct kvm *kvm,
-                       struct kvm_memory_slot *memslot)
+                                  const struct kvm_memory_slot *memslot)
 {
-       bool flush = false;
-       unsigned long *rmapp;
-       unsigned long last_index, index;
-
+       /* FIXME: const-ify all uses of struct kvm_memory_slot.  */
        spin_lock(&kvm->mmu_lock);
-
-       rmapp = memslot->arch.rmap[0];
-       last_index = gfn_to_index(memslot->base_gfn + memslot->npages - 1,
-                               memslot->base_gfn, PT_PAGE_TABLE_LEVEL);
-
-       for (index = 0; index <= last_index; ++index, ++rmapp) {
-               if (*rmapp)
-                       flush |= kvm_mmu_zap_collapsible_spte(kvm, rmapp);
-
-               if (need_resched() || spin_needbreak(&kvm->mmu_lock)) {
-                       if (flush) {
-                               kvm_flush_remote_tlbs(kvm);
-                               flush = false;
-                       }
-                       cond_resched_lock(&kvm->mmu_lock);
-               }
-       }
-
-       if (flush)
-               kvm_flush_remote_tlbs(kvm);
-
+       slot_handle_leaf(kvm, (struct kvm_memory_slot *)memslot,
+                        kvm_mmu_zap_collapsible_spte, true);
        spin_unlock(&kvm->mmu_lock);
 }
 
 void kvm_mmu_slot_leaf_clear_dirty(struct kvm *kvm,
                                   struct kvm_memory_slot *memslot)
 {
-       gfn_t last_gfn;
-       unsigned long *rmapp;
-       unsigned long last_index, index;
-       bool flush = false;
-
-       last_gfn = memslot->base_gfn + memslot->npages - 1;
+       bool flush;
 
        spin_lock(&kvm->mmu_lock);
-
-       rmapp = memslot->arch.rmap[PT_PAGE_TABLE_LEVEL - 1];
-       last_index = gfn_to_index(last_gfn, memslot->base_gfn,
-                       PT_PAGE_TABLE_LEVEL);
-
-       for (index = 0; index <= last_index; ++index, ++rmapp) {
-               if (*rmapp)
-                       flush |= __rmap_clear_dirty(kvm, rmapp);
-
-               if (need_resched() || spin_needbreak(&kvm->mmu_lock))
-                       cond_resched_lock(&kvm->mmu_lock);
-       }
-
+       flush = slot_handle_leaf(kvm, memslot, __rmap_clear_dirty, false);
        spin_unlock(&kvm->mmu_lock);
 
        lockdep_assert_held(&kvm->slots_lock);
@@ -4582,31 +4590,11 @@ EXPORT_SYMBOL_GPL(kvm_mmu_slot_leaf_clear_dirty);
 void kvm_mmu_slot_largepage_remove_write_access(struct kvm *kvm,
                                        struct kvm_memory_slot *memslot)
 {
-       gfn_t last_gfn;
-       int i;
-       bool flush = false;
-
-       last_gfn = memslot->base_gfn + memslot->npages - 1;
+       bool flush;
 
        spin_lock(&kvm->mmu_lock);
-
-       for (i = PT_PAGE_TABLE_LEVEL + 1; /* skip rmap for 4K page */
-            i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
-               unsigned long *rmapp;
-               unsigned long last_index, index;
-
-               rmapp = memslot->arch.rmap[i - PT_PAGE_TABLE_LEVEL];
-               last_index = gfn_to_index(last_gfn, memslot->base_gfn, i);
-
-               for (index = 0; index <= last_index; ++index, ++rmapp) {
-                       if (*rmapp)
-                               flush |= __rmap_write_protect(kvm, rmapp,
-                                               false);
-
-                       if (need_resched() || spin_needbreak(&kvm->mmu_lock))
-                               cond_resched_lock(&kvm->mmu_lock);
-               }
-       }
+       flush = slot_handle_large_level(kvm, memslot, slot_rmap_write_protect,
+                                       false);
        spin_unlock(&kvm->mmu_lock);
 
        /* see kvm_mmu_slot_remove_write_access */
@@ -4620,31 +4608,10 @@ EXPORT_SYMBOL_GPL(kvm_mmu_slot_largepage_remove_write_access);
 void kvm_mmu_slot_set_dirty(struct kvm *kvm,
                            struct kvm_memory_slot *memslot)
 {
-       gfn_t last_gfn;
-       int i;
-       bool flush = false;
-
-       last_gfn = memslot->base_gfn + memslot->npages - 1;
+       bool flush;
 
        spin_lock(&kvm->mmu_lock);
-
-       for (i = PT_PAGE_TABLE_LEVEL;
-            i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
-               unsigned long *rmapp;
-               unsigned long last_index, index;
-
-               rmapp = memslot->arch.rmap[i - PT_PAGE_TABLE_LEVEL];
-               last_index = gfn_to_index(last_gfn, memslot->base_gfn, i);
-
-               for (index = 0; index <= last_index; ++index, ++rmapp) {
-                       if (*rmapp)
-                               flush |= __rmap_set_dirty(kvm, rmapp);
-
-                       if (need_resched() || spin_needbreak(&kvm->mmu_lock))
-                               cond_resched_lock(&kvm->mmu_lock);
-               }
-       }
-
+       flush = slot_handle_all_level(kvm, memslot, __rmap_set_dirty, false);
        spin_unlock(&kvm->mmu_lock);
 
        lockdep_assert_held(&kvm->slots_lock);
@@ -4741,13 +4708,13 @@ static bool kvm_has_zapped_obsolete_pages(struct kvm *kvm)
        return unlikely(!list_empty_careful(&kvm->arch.zapped_obsolete_pages));
 }
 
-void kvm_mmu_invalidate_mmio_sptes(struct kvm *kvm)
+void kvm_mmu_invalidate_mmio_sptes(struct kvm *kvm, struct kvm_memslots *slots)
 {
        /*
         * The very rare case: if the generation-number is round,
         * zap all shadow pages.
         */
-       if (unlikely(kvm_current_mmio_generation(kvm) == 0)) {
+       if (unlikely((slots->generation & MMIO_GEN_MASK) == 0)) {
                printk_ratelimited(KERN_DEBUG "kvm: zapping shadow pages for mmio generation wraparound\n");
                kvm_mmu_invalidate_zap_all_pages(kvm);
        }
@@ -4869,15 +4836,18 @@ unsigned int kvm_mmu_calculate_mmu_pages(struct kvm *kvm)
        unsigned int  nr_pages = 0;
        struct kvm_memslots *slots;
        struct kvm_memory_slot *memslot;
+       int i;
 
-       slots = kvm_memslots(kvm);
+       for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
+               slots = __kvm_memslots(kvm, i);
 
-       kvm_for_each_memslot(memslot, slots)
-               nr_pages += memslot->npages;
+               kvm_for_each_memslot(memslot, slots)
+                       nr_pages += memslot->npages;
+       }
 
        nr_mmu_pages = nr_pages * KVM_PERMILLE_MMU_PAGES / 1000;
        nr_mmu_pages = max(nr_mmu_pages,
-                       (unsigned int) KVM_MIN_ALLOC_MMU_PAGES);
+                          (unsigned int) KVM_MIN_ALLOC_MMU_PAGES);
 
        return nr_mmu_pages;
 }