[WebAssembly] Fix function return type printing
authorDerek Schuff <dschuff@google.com>
Mon, 16 Nov 2015 21:12:41 +0000 (21:12 +0000)
committerDerek Schuff <dschuff@google.com>
Mon, 16 Nov 2015 21:12:41 +0000 (21:12 +0000)
Summary:
Previously return type information for a function was derived from
return dag nodes. But this didn't work for dags with != return node. So
instead compute it directly from the LLVM function as is done for imports.

Differential Revision: http://reviews.llvm.org/D14593

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@253251 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h
test/CodeGen/WebAssembly/func.ll

index 0ab8888015e2dc3d696a4b2751a131f1ad6b95b2..0516e7a3bd850cb28bfe6454ffbaa5e7756bc7ef 100644 (file)
@@ -146,14 +146,38 @@ void WebAssemblyAsmPrinter::EmitJumpTableInfo() {
   // Nothing to do; jump tables are incorporated into the instruction stream.
 }
 
+static void ComputeLegalValueVTs(const Function &F,
+                                 const TargetMachine &TM,
+                                 Type *Ty,
+                                 SmallVectorImpl<MVT> &ValueVTs) {
+  const DataLayout& DL(F.getParent()->getDataLayout());
+  const WebAssemblyTargetLowering &TLI =
+      *TM.getSubtarget<WebAssemblySubtarget>(F).getTargetLowering();
+  SmallVector<EVT, 4> VTs;
+  ComputeValueVTs(TLI, DL, Ty, VTs);
+
+  for (EVT VT : VTs) {
+    unsigned NumRegs = TLI.getNumRegisters(F.getContext(), VT);
+    MVT RegisterVT = TLI.getRegisterType(F.getContext(), VT);
+    for (unsigned i = 0; i != NumRegs; ++i)
+      ValueVTs.push_back(RegisterVT);
+  }
+}
+
 void WebAssemblyAsmPrinter::EmitFunctionBodyStart() {
   SmallString<128> Str;
   raw_svector_ostream OS(Str);
 
   for (MVT VT : MFI->getParams())
     OS << "\t" ".param " << toString(VT) << '\n';
-  for (MVT VT : MFI->getResults())
-    OS << "\t" ".result " << toString(VT) << '\n';
+
+  SmallVector<MVT, 4> ResultVTs;
+  const Function &F(*MF->getFunction());
+  ComputeLegalValueVTs(F, TM, F.getReturnType(), ResultVTs);
+  // If the return type needs to be legalized it will get converted into
+  // passing a pointer.
+  if (ResultVTs.size() == 1)
+    OS << "\t" ".result " << toString(ResultVTs.front()) << '\n';
 
   bool FirstVReg = true;
   for (unsigned Idx = 0, IdxE = MRI->getNumVirtRegs(); Idx != IdxE; ++Idx) {
@@ -210,20 +234,7 @@ void WebAssemblyAsmPrinter::EmitInstruction(const MachineInstr *MI) {
   }
 }
 
-static void ComputeLegalValueVTs(LLVMContext &Context,
-                                 const WebAssemblyTargetLowering &TLI,
-                                 const DataLayout &DL, Type *Ty,
-                                 SmallVectorImpl<MVT> &ValueVTs) {
-  SmallVector<EVT, 4> VTs;
-  ComputeValueVTs(TLI, DL, Ty, VTs);
 
-  for (EVT VT : VTs) {
-    unsigned NumRegs = TLI.getNumRegisters(Context, VT);
-    MVT RegisterVT = TLI.getRegisterType(Context, VT);
-    for (unsigned i = 0; i != NumRegs; ++i)
-      ValueVTs.push_back(RegisterVT);
-  }
-}
 
 void WebAssemblyAsmPrinter::EmitEndOfAsmFile(Module &M) {
   const DataLayout &DL = M.getDataLayout();
@@ -248,8 +259,7 @@ void WebAssemblyAsmPrinter::EmitEndOfAsmFile(Module &M) {
       // passing a pointer.
       bool SawParam = false;
       SmallVector<MVT, 4> ResultVTs;
-      ComputeLegalValueVTs(M.getContext(), TLI, DL, F.getReturnType(),
-                           ResultVTs);
+      ComputeLegalValueVTs(F, TM, F.getReturnType(), ResultVTs);
       if (ResultVTs.size() > 1) {
         ResultVTs.clear();
         OS << " (param " << toString(TLI.getPointerTy(DL));
@@ -258,20 +268,20 @@ void WebAssemblyAsmPrinter::EmitEndOfAsmFile(Module &M) {
 
       for (const Argument &A : F.args()) {
         SmallVector<MVT, 4> ParamVTs;
-        ComputeLegalValueVTs(M.getContext(), TLI, DL, A.getType(), ParamVTs);
-        for (EVT VT : ParamVTs) {
+        ComputeLegalValueVTs(F, TM, A.getType(), ParamVTs);
+        for (MVT VT : ParamVTs) {
           if (!SawParam) {
             OS << " (param";
             SawParam = true;
           }
-          OS << ' ' << toString(VT.getSimpleVT());
+          OS << ' ' << toString(VT);
         }
       }
       if (SawParam)
         OS << ')';
 
-      for (EVT VT : ResultVTs)
-        OS << " (result " << toString(VT.getSimpleVT()) << ')';
+      for (MVT VT : ResultVTs)
+        OS << " (result " << toString(VT) << ')';
 
       OS << '\n';
     }
index 47eea934b8cd759f80dd89864488a22c3787d5e5..bae4f526723ec2d254ff0948852664d66a98cd15 100644 (file)
@@ -326,8 +326,6 @@ SDValue WebAssemblyTargetLowering::LowerReturn(
     const SmallVectorImpl<ISD::OutputArg> &Outs,
     const SmallVectorImpl<SDValue> &OutVals, SDLoc DL,
     SelectionDAG &DAG) const {
-  MachineFunction &MF = DAG.getMachineFunction();
-
   assert(Outs.size() <= 1 && "WebAssembly can only return up to one value");
   if (CallConv != CallingConv::C)
     fail(DL, DAG, "WebAssembly doesn't support non-C calling conventions");
@@ -352,7 +350,6 @@ SDValue WebAssemblyTargetLowering::LowerReturn(
       fail(DL, DAG, "WebAssembly hasn't implemented cons regs last results");
     if (!Out.IsFixed)
       fail(DL, DAG, "WebAssembly doesn't support non-fixed results yet");
-    MF.getInfo<WebAssemblyFunctionInfo>()->addResult(Out.VT);
   }
 
   return Chain;
index 6e8b1dcb7a5949a088e4a40daa1368da04ae58de..9c23412b3cc04b7c02caafd090cc7e8a97f535a0 100644 (file)
@@ -28,7 +28,6 @@ class WebAssemblyFunctionInfo final : public MachineFunctionInfo {
   MachineFunction &MF;
 
   std::vector<MVT> Params;
-  std::vector<MVT> Results;
 
   /// A mapping from CodeGen vreg index to WebAssembly register number.
   std::vector<unsigned> WARegs;
@@ -48,9 +47,6 @@ public:
   void addParam(MVT VT) { Params.push_back(VT); }
   const std::vector<MVT> &getParams() const { return Params; }
 
-  void addResult(MVT VT) { Results.push_back(VT); }
-  const std::vector<MVT> &getResults() const { return Results; }
-
   static const unsigned UnusedReg = -1u;
 
   void stackifyVReg(unsigned VReg) {
index 17d20760e8ce6c9bd1cb0e92303c5ce9c13c5c17..97fd9913641be05bbad69b099dc9cdf6d6c54686 100644 (file)
@@ -45,3 +45,24 @@ define i32 @f2(i32 %p1, float %p2) {
 define void @f3(i32 %p1, float %p2) {
   ret void
 }
+
+; CHECK-LABEL: f4:
+; CHECK-NEXT: .param i32{{$}}
+; CHECK-NEXT: .result i32{{$}}
+; CHECK-NEXT: .local
+define i32 @f4(i32 %x) {
+entry:
+   %c = trunc i32 %x to i1
+   br i1 %c, label %true, label %false
+true:
+   ret i32 0
+false:
+   ret i32 1
+}
+
+; CHECK-LABEL: f5:
+; CHECK-NEXT: .result f32{{$}}
+; CHECK-NEXT: unreachable
+define float @f5()  {
+ unreachable
+}