}
 }
 
+static void mmu_free_roots(struct kvm_vcpu *vcpu)
+{
+       int i;
+
+#ifdef CONFIG_X86_64
+       if (vcpu->mmu.shadow_root_level == PT64_ROOT_LEVEL) {
+               hpa_t root = vcpu->mmu.root_hpa;
+
+               ASSERT(VALID_PAGE(root));
+               release_pt_page_64(vcpu, root, PT64_ROOT_LEVEL);
+               vcpu->mmu.root_hpa = INVALID_PAGE;
+               return;
+       }
+#endif
+       for (i = 0; i < 4; ++i) {
+               hpa_t root = vcpu->mmu.pae_root[i];
+
+               ASSERT(VALID_PAGE(root));
+               root &= PT64_BASE_ADDR_MASK;
+               release_pt_page_64(vcpu, root, PT32E_ROOT_LEVEL - 1);
+               vcpu->mmu.pae_root[i] = INVALID_PAGE;
+       }
+       vcpu->mmu.root_hpa = INVALID_PAGE;
+}
+
+static void mmu_alloc_roots(struct kvm_vcpu *vcpu)
+{
+       int i;
+
+#ifdef CONFIG_X86_64
+       if (vcpu->mmu.shadow_root_level == PT64_ROOT_LEVEL) {
+               hpa_t root = vcpu->mmu.root_hpa;
+
+               ASSERT(!VALID_PAGE(root));
+               root = kvm_mmu_alloc_page(vcpu, NULL);
+               vcpu->mmu.root_hpa = root;
+               return;
+       }
+#endif
+       for (i = 0; i < 4; ++i) {
+               hpa_t root = vcpu->mmu.pae_root[i];
+
+               ASSERT(!VALID_PAGE(root));
+               root = kvm_mmu_alloc_page(vcpu, NULL);
+               vcpu->mmu.pae_root[i] = root | PT_PRESENT_MASK;
+       }
+       vcpu->mmu.root_hpa = __pa(vcpu->mmu.pae_root);
+}
+
 static void nonpaging_flush(struct kvm_vcpu *vcpu)
 {
        hpa_t root = vcpu->mmu.root_hpa;
 
        ++kvm_stat.tlb_flush;
        pgprintk("nonpaging_flush\n");
-       ASSERT(VALID_PAGE(root));
-       release_pt_page_64(vcpu, root, vcpu->mmu.shadow_root_level);
-       root = kvm_mmu_alloc_page(vcpu, NULL);
-       ASSERT(VALID_PAGE(root));
-       vcpu->mmu.root_hpa = root;
-       if (is_paging(vcpu))
-               root |= (vcpu->cr3 & (CR3_PCD_MASK | CR3_WPT_MASK));
+       mmu_free_roots(vcpu);
+       mmu_alloc_roots(vcpu);
        kvm_arch_ops->set_cr3(vcpu, root);
        kvm_arch_ops->tlb_flush(vcpu);
 }
 
 static void nonpaging_free(struct kvm_vcpu *vcpu)
 {
-       hpa_t root;
-
-       ASSERT(vcpu);
-       root = vcpu->mmu.root_hpa;
-       if (VALID_PAGE(root))
-               release_pt_page_64(vcpu, root, vcpu->mmu.shadow_root_level);
-       vcpu->mmu.root_hpa = INVALID_PAGE;
+       mmu_free_roots(vcpu);
 }
 
 static int nonpaging_init_context(struct kvm_vcpu *vcpu)
        context->free = nonpaging_free;
        context->root_level = PT32E_ROOT_LEVEL;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
-       context->root_hpa = kvm_mmu_alloc_page(vcpu, NULL);
+       mmu_alloc_roots(vcpu);
        ASSERT(VALID_PAGE(context->root_hpa));
        kvm_arch_ops->set_cr3(vcpu, context->root_hpa);
        return 0;
 #include "paging_tmpl.h"
 #undef PTTYPE
 
-static int paging64_init_context(struct kvm_vcpu *vcpu)
+static int paging64_init_context_common(struct kvm_vcpu *vcpu, int level)
 {
        struct kvm_mmu *context = &vcpu->mmu;
 
        context->inval_page = paging_inval_page;
        context->gva_to_gpa = paging64_gva_to_gpa;
        context->free = paging_free;
-       context->root_level = PT64_ROOT_LEVEL;
-       context->shadow_root_level = PT64_ROOT_LEVEL;
-       context->root_hpa = kvm_mmu_alloc_page(vcpu, NULL);
+       context->root_level = level;
+       context->shadow_root_level = level;
+       mmu_alloc_roots(vcpu);
        ASSERT(VALID_PAGE(context->root_hpa));
        kvm_arch_ops->set_cr3(vcpu, context->root_hpa |
                    (vcpu->cr3 & (CR3_PCD_MASK | CR3_WPT_MASK)));
        return 0;
 }
 
+static int paging64_init_context(struct kvm_vcpu *vcpu)
+{
+       return paging64_init_context_common(vcpu, PT64_ROOT_LEVEL);
+}
+
 static int paging32_init_context(struct kvm_vcpu *vcpu)
 {
        struct kvm_mmu *context = &vcpu->mmu;
        context->free = paging_free;
        context->root_level = PT32_ROOT_LEVEL;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
-       context->root_hpa = kvm_mmu_alloc_page(vcpu, NULL);
+       mmu_alloc_roots(vcpu);
        ASSERT(VALID_PAGE(context->root_hpa));
        kvm_arch_ops->set_cr3(vcpu, context->root_hpa |
                    (vcpu->cr3 & (CR3_PCD_MASK | CR3_WPT_MASK)));
 
 static int paging32E_init_context(struct kvm_vcpu *vcpu)
 {
-       int ret;
-
-       if ((ret = paging64_init_context(vcpu)))
-               return ret;
-
-       vcpu->mmu.root_level = PT32E_ROOT_LEVEL;
-       vcpu->mmu.shadow_root_level = PT32E_ROOT_LEVEL;
-       return 0;
+       return paging64_init_context_common(vcpu, PT32E_ROOT_LEVEL);
 }
 
 static int init_kvm_mmu(struct kvm_vcpu *vcpu)
                __free_page(pfn_to_page(page->page_hpa >> PAGE_SHIFT));
                page->page_hpa = INVALID_PAGE;
        }
+       free_page((unsigned long)vcpu->mmu.pae_root);
 }
 
 static int alloc_mmu_pages(struct kvm_vcpu *vcpu)
 {
+       struct page *page;
        int i;
 
        ASSERT(vcpu);
 
        for (i = 0; i < KVM_NUM_MMU_PAGES; i++) {
-               struct page *page;
                struct kvm_mmu_page *page_header = &vcpu->page_header_buf[i];
 
                INIT_LIST_HEAD(&page_header->link);
-               if ((page = alloc_page(GFP_KVM_MMU)) == NULL)
+               if ((page = alloc_page(GFP_KERNEL)) == NULL)
                        goto error_1;
                page->private = (unsigned long)page_header;
                page_header->page_hpa = (hpa_t)page_to_pfn(page) << PAGE_SHIFT;
                memset(__va(page_header->page_hpa), 0, PAGE_SIZE);
                list_add(&page_header->link, &vcpu->free_pages);
        }
+
+       /*
+        * When emulating 32-bit mode, cr3 is only 32 bits even on x86_64.
+        * Therefore we need to allocate shadow page tables in the first
+        * 4GB of memory, which happens to fit the DMA32 zone.
+        */
+       page = alloc_page(GFP_KERNEL | __GFP_DMA32);
+       if (!page)
+               goto error_1;
+       vcpu->mmu.pae_root = page_address(page);
+       for (i = 0; i < 4; ++i)
+               vcpu->mmu.pae_root[i] = INVALID_PAGE;
+
        return 0;
 
 error_1: