a2a54da8590c8e2829cf3b3c4d49ba7a75b65039
[oota-llvm.git] / lib / Transforms / Utils / SymbolRewriter.cpp
1 //===- SymbolRewriter.cpp - Symbol Rewriter ---------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // SymbolRewriter is a LLVM pass which can rewrite symbols transparently within
11 // existing code.  It is implemented as a compiler pass and is configured via a
12 // YAML configuration file.
13 //
14 // The YAML configuration file format is as follows:
15 //
16 // RewriteMapFile := RewriteDescriptors
17 // RewriteDescriptors := RewriteDescriptor | RewriteDescriptors
18 // RewriteDescriptor := RewriteDescriptorType ':' '{' RewriteDescriptorFields '}'
19 // RewriteDescriptorFields := RewriteDescriptorField | RewriteDescriptorFields
20 // RewriteDescriptorField := FieldIdentifier ':' FieldValue ','
21 // RewriteDescriptorType := Identifier
22 // FieldIdentifier := Identifier
23 // FieldValue := Identifier
24 // Identifier := [0-9a-zA-Z]+
25 //
26 // Currently, the following descriptor types are supported:
27 //
28 // - function:          (function rewriting)
29 //      + Source        (original name of the function)
30 //      + Target        (explicit transformation)
31 //      + Transform     (pattern transformation)
32 //      + Naked         (boolean, whether the function is undecorated)
33 // - global variable:   (external linkage global variable rewriting)
34 //      + Source        (original name of externally visible variable)
35 //      + Target        (explicit transformation)
36 //      + Transform     (pattern transformation)
37 // - global alias:      (global alias rewriting)
38 //      + Source        (original name of the aliased name)
39 //      + Target        (explicit transformation)
40 //      + Transform     (pattern transformation)
41 //
42 // Note that source and exactly one of [Target, Transform] must be provided
43 //
44 // New rewrite descriptors can be created.  Addding a new rewrite descriptor
45 // involves:
46 //
47 //  a) extended the rewrite descriptor kind enumeration
48 //     (<anonymous>::RewriteDescriptor::RewriteDescriptorType)
49 //  b) implementing the new descriptor
50 //     (c.f. <anonymous>::ExplicitRewriteFunctionDescriptor)
51 //  c) extending the rewrite map parser
52 //     (<anonymous>::RewriteMapParser::parseEntry)
53 //
54 //  Specify to rewrite the symbols using the `-rewrite-symbols` option, and
55 //  specify the map file to use for the rewriting via the `-rewrite-map-file`
56 //  option.
57 //
58 //===----------------------------------------------------------------------===//
59
60 #define DEBUG_TYPE "symbol-rewriter"
61 #include "llvm/CodeGen/Passes.h"
62 #include "llvm/Pass.h"
63 #include "llvm/ADT/SmallString.h"
64 #include "llvm/IR/LegacyPassManager.h"
65 #include "llvm/Support/CommandLine.h"
66 #include "llvm/Support/Debug.h"
67 #include "llvm/Support/MemoryBuffer.h"
68 #include "llvm/Support/Regex.h"
69 #include "llvm/Support/SourceMgr.h"
70 #include "llvm/Support/YAMLParser.h"
71 #include "llvm/Support/raw_ostream.h"
72 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
73 #include "llvm/Transforms/Utils/SymbolRewriter.h"
74
75 using namespace llvm;
76 using namespace SymbolRewriter;
77
78 static cl::list<std::string> RewriteMapFiles("rewrite-map-file",
79                                              cl::desc("Symbol Rewrite Map"),
80                                              cl::value_desc("filename"));
81
82 static void rewriteComdat(Module &M, GlobalObject *GO,
83                           const std::string &Source,
84                           const std::string &Target) {
85   if (Comdat *CD = GO->getComdat()) {
86     auto &Comdats = M.getComdatSymbolTable();
87
88     Comdat *C = M.getOrInsertComdat(Target);
89     C->setSelectionKind(CD->getSelectionKind());
90     GO->setComdat(C);
91
92     Comdats.erase(Comdats.find(Source));
93   }
94 }
95
96 namespace {
97 template <RewriteDescriptor::Type DT, typename ValueType,
98           ValueType *(llvm::Module::*Get)(StringRef) const>
99 class ExplicitRewriteDescriptor : public RewriteDescriptor {
100 public:
101   const std::string Source;
102   const std::string Target;
103
104   ExplicitRewriteDescriptor(StringRef S, StringRef T, const bool Naked)
105       : RewriteDescriptor(DT), Source(Naked ? StringRef("\01" + S.str()) : S),
106         Target(T) {}
107
108   bool performOnModule(Module &M) override;
109
110   static bool classof(const RewriteDescriptor *RD) {
111     return RD->getType() == DT;
112   }
113 };
114
115 template <RewriteDescriptor::Type DT, typename ValueType,
116           ValueType *(llvm::Module::*Get)(StringRef) const>
117 bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) {
118   bool Changed = false;
119   if (ValueType *S = (M.*Get)(Source)) {
120     if (GlobalObject *GO = dyn_cast<GlobalObject>(S))
121       rewriteComdat(M, GO, Source, Target);
122
123     if (Value *T = (M.*Get)(Target))
124       S->setValueName(T->getValueName());
125     else
126       S->setName(Target);
127
128     Changed = true;
129   }
130   return Changed;
131 }
132
133 template <RewriteDescriptor::Type DT, typename ValueType,
134           ValueType *(llvm::Module::*Get)(StringRef) const,
135           iterator_range<typename iplist<ValueType>::iterator>
136           (llvm::Module::*Iterator)()>
137 class PatternRewriteDescriptor : public RewriteDescriptor {
138 public:
139   const std::string Pattern;
140   const std::string Transform;
141
142   PatternRewriteDescriptor(StringRef P, StringRef T)
143     : RewriteDescriptor(DT), Pattern(P), Transform(T) { }
144
145   bool performOnModule(Module &M) override;
146
147   static bool classof(const RewriteDescriptor *RD) {
148     return RD->getType() == DT;
149   }
150 };
151
152 template <RewriteDescriptor::Type DT, typename ValueType,
153           ValueType *(llvm::Module::*Get)(StringRef) const,
154           iterator_range<typename iplist<ValueType>::iterator>
155           (llvm::Module::*Iterator)()>
156 bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>::
157 performOnModule(Module &M) {
158   bool Changed = false;
159   for (auto &C : (M.*Iterator)()) {
160     std::string Error;
161
162     std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error);
163     if (!Error.empty())
164       report_fatal_error("unable to transforn " + C.getName() + " in " +
165                          M.getModuleIdentifier() + ": " + Error);
166
167     if (C.getName() == Name)
168       continue;
169
170     if (GlobalObject *GO = dyn_cast<GlobalObject>(&C))
171       rewriteComdat(M, GO, C.getName(), Name);
172
173     if (Value *V = (M.*Get)(Name))
174       C.setValueName(V->getValueName());
175     else
176       C.setName(Name);
177
178     Changed = true;
179   }
180   return Changed;
181 }
182
183 /// Represents a rewrite for an explicitly named (function) symbol.  Both the
184 /// source function name and target function name of the transformation are
185 /// explicitly spelt out.
186 typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function,
187                                   llvm::Function, &llvm::Module::getFunction>
188     ExplicitRewriteFunctionDescriptor;
189
190 /// Represents a rewrite for an explicitly named (global variable) symbol.  Both
191 /// the source variable name and target variable name are spelt out.  This
192 /// applies only to module level variables.
193 typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
194                                   llvm::GlobalVariable,
195                                   &llvm::Module::getGlobalVariable>
196     ExplicitRewriteGlobalVariableDescriptor;
197
198 /// Represents a rewrite for an explicitly named global alias.  Both the source
199 /// and target name are explicitly spelt out.
200 typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
201                                   llvm::GlobalAlias,
202                                   &llvm::Module::getNamedAlias>
203     ExplicitRewriteNamedAliasDescriptor;
204
205 /// Represents a rewrite for a regular expression based pattern for functions.
206 /// A pattern for the function name is provided and a transformation for that
207 /// pattern to determine the target function name create the rewrite rule.
208 typedef PatternRewriteDescriptor<RewriteDescriptor::Type::Function,
209                                  llvm::Function, &llvm::Module::getFunction,
210                                  &llvm::Module::functions>
211     PatternRewriteFunctionDescriptor;
212
213 /// Represents a rewrite for a global variable based upon a matching pattern.
214 /// Each global variable matching the provided pattern will be transformed as
215 /// described in the transformation pattern for the target.  Applies only to
216 /// module level variables.
217 typedef PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
218                                  llvm::GlobalVariable,
219                                  &llvm::Module::getGlobalVariable,
220                                  &llvm::Module::globals>
221     PatternRewriteGlobalVariableDescriptor;
222
223 /// PatternRewriteNamedAliasDescriptor - represents a rewrite for global
224 /// aliases which match a given pattern.  The provided transformation will be
225 /// applied to each of the matching names.
226 typedef PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
227                                  llvm::GlobalAlias,
228                                  &llvm::Module::getNamedAlias,
229                                  &llvm::Module::aliases>
230     PatternRewriteNamedAliasDescriptor;
231 } // namespace
232
233 bool RewriteMapParser::parse(const std::string &MapFile,
234                              RewriteDescriptorList *DL) {
235   ErrorOr<std::unique_ptr<MemoryBuffer>> Mapping =
236       MemoryBuffer::getFile(MapFile);
237
238   if (!Mapping)
239     report_fatal_error("unable to read rewrite map '" + MapFile + "': " +
240                        Mapping.getError().message());
241
242   if (!parse(*Mapping, DL))
243     report_fatal_error("unable to parse rewrite map '" + MapFile + "'");
244
245   return true;
246 }
247
248 bool RewriteMapParser::parse(std::unique_ptr<MemoryBuffer> &MapFile,
249                              RewriteDescriptorList *DL) {
250   SourceMgr SM;
251   yaml::Stream YS(MapFile->getBuffer(), SM);
252
253   for (auto &Document : YS) {
254     yaml::MappingNode *DescriptorList;
255
256     // ignore empty documents
257     if (isa<yaml::NullNode>(Document.getRoot()))
258       continue;
259
260     DescriptorList = dyn_cast<yaml::MappingNode>(Document.getRoot());
261     if (!DescriptorList) {
262       YS.printError(Document.getRoot(), "DescriptorList node must be a map");
263       return false;
264     }
265
266     for (auto &Descriptor : *DescriptorList)
267       if (!parseEntry(YS, Descriptor, DL))
268         return false;
269   }
270
271   return true;
272 }
273
274 bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry,
275                                   RewriteDescriptorList *DL) {
276   yaml::ScalarNode *Key;
277   yaml::MappingNode *Value;
278   SmallString<32> KeyStorage;
279   StringRef RewriteType;
280
281   Key = dyn_cast<yaml::ScalarNode>(Entry.getKey());
282   if (!Key) {
283     YS.printError(Entry.getKey(), "rewrite type must be a scalar");
284     return false;
285   }
286
287   Value = dyn_cast<yaml::MappingNode>(Entry.getValue());
288   if (!Value) {
289     YS.printError(Entry.getValue(), "rewrite descriptor must be a map");
290     return false;
291   }
292
293   RewriteType = Key->getValue(KeyStorage);
294   if (RewriteType.equals("function"))
295     return parseRewriteFunctionDescriptor(YS, Key, Value, DL);
296   else if (RewriteType.equals("global variable"))
297     return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL);
298   else if (RewriteType.equals("global alias"))
299     return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL);
300
301   YS.printError(Entry.getKey(), "unknown rewrite type");
302   return false;
303 }
304
305 bool RewriteMapParser::
306 parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
307                                yaml::MappingNode *Descriptor,
308                                RewriteDescriptorList *DL) {
309   bool Naked = false;
310   std::string Source;
311   std::string Target;
312   std::string Transform;
313
314   for (auto &Field : *Descriptor) {
315     yaml::ScalarNode *Key;
316     yaml::ScalarNode *Value;
317     SmallString<32> KeyStorage;
318     SmallString<32> ValueStorage;
319     StringRef KeyValue;
320
321     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
322     if (!Key) {
323       YS.printError(Field.getKey(), "descriptor key must be a scalar");
324       return false;
325     }
326
327     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
328     if (!Value) {
329       YS.printError(Field.getValue(), "descriptor value must be a scalar");
330       return false;
331     }
332
333     KeyValue = Key->getValue(KeyStorage);
334     if (KeyValue.equals("source")) {
335       std::string Error;
336
337       Source = Value->getValue(ValueStorage);
338       if (!Regex(Source).isValid(Error)) {
339         YS.printError(Field.getKey(), "invalid regex: " + Error);
340         return false;
341       }
342     } else if (KeyValue.equals("target")) {
343       Target = Value->getValue(ValueStorage);
344     } else if (KeyValue.equals("transform")) {
345       Transform = Value->getValue(ValueStorage);
346     } else if (KeyValue.equals("naked")) {
347       std::string Undecorated;
348
349       Undecorated = Value->getValue(ValueStorage);
350       Naked = StringRef(Undecorated).lower() == "true" || Undecorated == "1";
351     } else {
352       YS.printError(Field.getKey(), "unknown key for function");
353       return false;
354     }
355   }
356
357   if (Transform.empty() == Target.empty()) {
358     YS.printError(Descriptor,
359                   "exactly one of transform or target must be specified");
360     return false;
361   }
362
363   // TODO see if there is a more elegant solution to selecting the rewrite
364   // descriptor type
365   if (!Target.empty())
366     DL->push_back(new ExplicitRewriteFunctionDescriptor(Source, Target, Naked));
367   else
368     DL->push_back(new PatternRewriteFunctionDescriptor(Source, Transform));
369
370   return true;
371 }
372
373 bool RewriteMapParser::
374 parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
375                                      yaml::MappingNode *Descriptor,
376                                      RewriteDescriptorList *DL) {
377   std::string Source;
378   std::string Target;
379   std::string Transform;
380
381   for (auto &Field : *Descriptor) {
382     yaml::ScalarNode *Key;
383     yaml::ScalarNode *Value;
384     SmallString<32> KeyStorage;
385     SmallString<32> ValueStorage;
386     StringRef KeyValue;
387
388     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
389     if (!Key) {
390       YS.printError(Field.getKey(), "descriptor Key must be a scalar");
391       return false;
392     }
393
394     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
395     if (!Value) {
396       YS.printError(Field.getValue(), "descriptor value must be a scalar");
397       return false;
398     }
399
400     KeyValue = Key->getValue(KeyStorage);
401     if (KeyValue.equals("source")) {
402       std::string Error;
403
404       Source = Value->getValue(ValueStorage);
405       if (!Regex(Source).isValid(Error)) {
406         YS.printError(Field.getKey(), "invalid regex: " + Error);
407         return false;
408       }
409     } else if (KeyValue.equals("target")) {
410       Target = Value->getValue(ValueStorage);
411     } else if (KeyValue.equals("transform")) {
412       Transform = Value->getValue(ValueStorage);
413     } else {
414       YS.printError(Field.getKey(), "unknown Key for Global Variable");
415       return false;
416     }
417   }
418
419   if (Transform.empty() == Target.empty()) {
420     YS.printError(Descriptor,
421                   "exactly one of transform or target must be specified");
422     return false;
423   }
424
425   if (!Target.empty())
426     DL->push_back(new ExplicitRewriteGlobalVariableDescriptor(Source, Target,
427                                                               /*Naked*/false));
428   else
429     DL->push_back(new PatternRewriteGlobalVariableDescriptor(Source,
430                                                              Transform));
431
432   return true;
433 }
434
435 bool RewriteMapParser::
436 parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
437                                   yaml::MappingNode *Descriptor,
438                                   RewriteDescriptorList *DL) {
439   std::string Source;
440   std::string Target;
441   std::string Transform;
442
443   for (auto &Field : *Descriptor) {
444     yaml::ScalarNode *Key;
445     yaml::ScalarNode *Value;
446     SmallString<32> KeyStorage;
447     SmallString<32> ValueStorage;
448     StringRef KeyValue;
449
450     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
451     if (!Key) {
452       YS.printError(Field.getKey(), "descriptor key must be a scalar");
453       return false;
454     }
455
456     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
457     if (!Value) {
458       YS.printError(Field.getValue(), "descriptor value must be a scalar");
459       return false;
460     }
461
462     KeyValue = Key->getValue(KeyStorage);
463     if (KeyValue.equals("source")) {
464       std::string Error;
465
466       Source = Value->getValue(ValueStorage);
467       if (!Regex(Source).isValid(Error)) {
468         YS.printError(Field.getKey(), "invalid regex: " + Error);
469         return false;
470       }
471     } else if (KeyValue.equals("target")) {
472       Target = Value->getValue(ValueStorage);
473     } else if (KeyValue.equals("transform")) {
474       Transform = Value->getValue(ValueStorage);
475     } else {
476       YS.printError(Field.getKey(), "unknown key for Global Alias");
477       return false;
478     }
479   }
480
481   if (Transform.empty() == Target.empty()) {
482     YS.printError(Descriptor,
483                   "exactly one of transform or target must be specified");
484     return false;
485   }
486
487   if (!Target.empty())
488     DL->push_back(new ExplicitRewriteNamedAliasDescriptor(Source, Target,
489                                                           /*Naked*/false));
490   else
491     DL->push_back(new PatternRewriteNamedAliasDescriptor(Source, Transform));
492
493   return true;
494 }
495
496 namespace {
497 class RewriteSymbols : public ModulePass {
498 public:
499   static char ID; // Pass identification, replacement for typeid
500
501   RewriteSymbols();
502   RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL);
503
504   bool runOnModule(Module &M) override;
505
506 private:
507   void loadAndParseMapFiles();
508
509   SymbolRewriter::RewriteDescriptorList Descriptors;
510 };
511
512 char RewriteSymbols::ID = 0;
513
514 RewriteSymbols::RewriteSymbols() : ModulePass(ID) {
515   initializeRewriteSymbolsPass(*PassRegistry::getPassRegistry());
516   loadAndParseMapFiles();
517 }
518
519 RewriteSymbols::RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL)
520     : ModulePass(ID) {
521   Descriptors.splice(Descriptors.begin(), DL);
522 }
523
524 bool RewriteSymbols::runOnModule(Module &M) {
525   bool Changed;
526
527   Changed = false;
528   for (auto &Descriptor : Descriptors)
529     Changed |= Descriptor.performOnModule(M);
530
531   return Changed;
532 }
533
534 void RewriteSymbols::loadAndParseMapFiles() {
535   const std::vector<std::string> MapFiles(RewriteMapFiles);
536   SymbolRewriter::RewriteMapParser parser;
537
538   for (const auto &MapFile : MapFiles)
539     parser.parse(MapFile, &Descriptors);
540 }
541 }
542
543 INITIALIZE_PASS(RewriteSymbols, "rewrite-symbols", "Rewrite Symbols", false,
544                 false)
545
546 ModulePass *llvm::createRewriteSymbolsPass() { return new RewriteSymbols(); }
547
548 ModulePass *
549 llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) {
550   return new RewriteSymbols(DL);
551 }