Const-correct and prevent a copy of a SmallPtrSet.
[oota-llvm.git] / lib / TableGen / Record.cpp
index c553a21c261e0a414afabfc9a8f4582ac7ba8d6b..cb21be7b7627ae17100ec3ae432c497bd711e3ac 100644 (file)
@@ -114,8 +114,21 @@ Init *BitRecTy::convertValue(IntInit *II) {
 
 Init *BitRecTy::convertValue(TypedInit *VI) {
   RecTy *Ty = VI->getType();
-  if (isa<BitRecTy>(Ty) || isa<BitsRecTy>(Ty) || isa<IntRecTy>(Ty))
+  if (isa<BitRecTy>(Ty))
     return VI;  // Accept variable if it is already of bit type!
+  if (auto *BitsTy = dyn_cast<BitsRecTy>(Ty))
+    // Accept only bits<1> expression.
+    return BitsTy->getNumBits() == 1 ? VI : nullptr;
+  // Ternary !if can be converted to bit, but only if both sides are
+  // convertible to a bit.
+  if (TernOpInit *TOI = dyn_cast<TernOpInit>(VI)) {
+    if (TOI->getOpcode() != TernOpInit::TernaryOp::IF)
+      return nullptr;
+    if (!TOI->getMHS()->convertInitializerTo(BitRecTy::get()) ||
+        !TOI->getRHS()->convertInitializerTo(BitRecTy::get()))
+      return nullptr;
+    return TOI;
+  }
   return nullptr;
 }
 
@@ -811,20 +824,14 @@ Init *UnOpInit::Fold(Record *CurRec, MultiClass *CurMultiClass) const {
   }
   case HEAD: {
     if (ListInit *LHSl = dyn_cast<ListInit>(LHS)) {
-      if (LHSl->getSize() == 0) {
-        assert(0 && "Empty list in car");
-        return nullptr;
-      }
+      assert(LHSl->getSize() != 0 && "Empty list in car");
       return LHSl->getElement(0);
     }
     break;
   }
   case TAIL: {
     if (ListInit *LHSl = dyn_cast<ListInit>(LHS)) {
-      if (LHSl->getSize() == 0) {
-        assert(0 && "Empty list in cdr");
-        return nullptr;
-      }
+      assert(LHSl->getSize() != 0 && "Empty list in cdr");
       // Note the +1.  We can't just pass the result of getValues()
       // directly.
       ArrayRef<Init *>::iterator begin = LHSl->getValues().begin()+1;
@@ -958,17 +965,21 @@ Init *BinOpInit::Fold(Record *CurRec, MultiClass *CurMultiClass) const {
     break;
   }
   case ADD:
+  case AND:
   case SHL:
   case SRA:
   case SRL: {
-    IntInit *LHSi = dyn_cast<IntInit>(LHS);
-    IntInit *RHSi = dyn_cast<IntInit>(RHS);
+    IntInit *LHSi =
+      dyn_cast_or_null<IntInit>(LHS->convertInitializerTo(IntRecTy::get()));
+    IntInit *RHSi =
+      dyn_cast_or_null<IntInit>(RHS->convertInitializerTo(IntRecTy::get()));
     if (LHSi && RHSi) {
       int64_t LHSv = LHSi->getValue(), RHSv = RHSi->getValue();
       int64_t Result;
       switch (getOpcode()) {
       default: llvm_unreachable("Bad opcode!");
       case ADD: Result = LHSv +  RHSv; break;
+      case AND: Result = LHSv &  RHSv; break;
       case SHL: Result = LHSv << RHSv; break;
       case SRA: Result = LHSv >> RHSv; break;
       case SRL: Result = (uint64_t)LHSv >> (uint64_t)RHSv; break;
@@ -995,6 +1006,7 @@ std::string BinOpInit::getAsString() const {
   switch (Opc) {
   case CONCAT: Result = "!con"; break;
   case ADD: Result = "!add"; break;
+  case AND: Result = "!and"; break;
   case SHL: Result = "!shl"; break;
   case SRA: Result = "!sra"; break;
   case SRL: Result = "!srl"; break;