aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r--scripts/vulkaninfo_generator.py67
1 files changed, 60 insertions, 7 deletions
diff --git a/scripts/vulkaninfo_generator.py b/scripts/vulkaninfo_generator.py
index 318ecd51..3846e345 100644
--- a/scripts/vulkaninfo_generator.py
+++ b/scripts/vulkaninfo_generator.py
@@ -92,6 +92,9 @@ flags_to_gen = ['VkSurfaceTransformFlagsKHR', 'VkCompositeAlphaFlagsKHR',
'VkDeviceGroupPresentModeFlagsKHR', 'VkFormatFeatureFlags', 'VkMemoryPropertyFlags', 'VkMemoryHeapFlags']
flags_strings_to_gen = ['VkQueueFlags']
+struct_comparisons_to_gen = ['VkSurfaceFormatKHR', 'VkSurfaceFormat2KHR', 'VkSurfaceCapabilitiesKHR',
+ 'VkSurfaceCapabilities2KHR', 'VkSurfaceCapabilities2EXT']
+
# iostream or custom outputter handles these types
predefined_types = ['char', 'VkBool32', 'uint32_t', 'uint8_t', 'int32_t',
'float', 'uint64_t', 'size_t', 'VkDeviceSize', 'VkConformanceVersionKHR']
@@ -177,6 +180,7 @@ class VulkanInfoGenerator(OutputGenerator):
self.flags = set()
self.bitmasks = set()
self.structures = set()
+ self.structs_to_comp = set()
self.all_structures = set()
self.types_to_gen = set()
@@ -216,6 +220,12 @@ class VulkanInfoGenerator(OutputGenerator):
types_to_gen = types_to_gen.union(
GatherTypesToGen(self.extension_sets[key]))
+ structs_to_comp = set()
+ for s in struct_comparisons_to_gen:
+ structs_to_comp.add(s)
+ structs_to_comp = structs_to_comp.union(
+ GatherTypesToGen(self.structs_to_comp))
+
self.enums = sorted(self.enums, key=operator.attrgetter('name'))
self.flags = sorted(self.flags, key=operator.attrgetter('name'))
self.bitmasks = sorted(self.bitmasks, key=operator.attrgetter('name'))
@@ -275,6 +285,14 @@ class VulkanInfoGenerator(OutputGenerator):
out += PrintChainIterator(key,
self.extension_sets[key], value.get('type'))
+ for s in self.all_structures:
+ if s.name in structs_to_comp:
+ out += PrintStructComparisonForwardDef(s)
+
+ for s in self.all_structures:
+ if s.name in structs_to_comp:
+ out += PrintStructComparison(s)
+
gen.write(out, file=self.outFile)
gen.OutputGenerator.endFile(self)
@@ -311,6 +329,10 @@ class VulkanInfoGenerator(OutputGenerator):
self.structures.add(VulkanStructure(
name, typeinfo.elem, self.constants, self.extTypes))
+ if typeinfo.elem.get('category') == 'struct' and name in struct_comparisons_to_gen:
+ self.structs_to_comp.add(VulkanStructure(
+ name, typeinfo.elem, self.constants, self.extTypes))
+
if typeinfo.elem.get('category') == 'struct':
self.all_structures.add(VulkanStructure(
name, typeinfo.elem, self.constants, self.extTypes))
@@ -328,13 +350,19 @@ class VulkanInfoGenerator(OutputGenerator):
def GatherTypesToGen(structures):
- types_to_gen = set()
- for s in structures:
- types_to_gen.add(s.name)
- for m in s.members:
- if m.typeID not in predefined_types and m.name not in ['sType', 'pNext']:
- types_to_gen.add(m.typeID)
- return types_to_gen
+ types = set()
+ added_stuff = True # repeat until no new types are added
+ while added_stuff == True:
+ added_stuff = False
+ for s in structures:
+ size = len(types)
+ types.add(s.name)
+ if len(types) != size:
+ added_stuff = True
+ for m in s.members:
+ if m.typeID not in predefined_types and m.name not in ['sType', 'pNext']:
+ types.add(m.typeID)
+ return types
def GetExtension(name, generator):
@@ -608,6 +636,31 @@ def PrintChainIterator(listName, structures, checkExtLoc):
return out
+def PrintStructComparisonForwardDef(structure):
+ out = ''
+ out += "bool operator==(const " + structure.name + \
+ " & a, const " + structure.name + " b);\n"
+ return out
+
+
+def PrintStructComparison(structure):
+ out = ''
+ out += "bool operator==(const " + structure.name + \
+ " & a, const " + structure.name + " b) {\n"
+ out += " return "
+ is_first = True
+ for m in structure.members:
+ if m.name not in ['sType', 'pNext']:
+ if not is_first:
+ out += "\n && "
+ else:
+ is_first = False
+ out += "a." + m.name + " == b." + m.name
+ out += ";\n"
+ out += "}\n"
+ return out
+
+
def isPow2(num):
return num != 0 and ((num & (num - 1)) == 0)