代码量390行,能讲明白吗?代码如下:
`import json
import re
import sys
import tkinter as tk
from tkinter import filedialog, messagebox
# 预处理函数 (preprocessing helpers)
# One alternation scanned left to right, so whichever construct starts first
# wins: a "//" inside a string stays, a "//" inside /* ... */ is part of the
# block comment, and vice versa.
_COMMENT_OR_LITERAL = re.compile(
    r'"(?:\\.|[^"\\\n])*"'   # "string literal"   -> kept verbatim
    r"|'(?:\\.|[^'\\\n])*'"  # 'c' char literal   -> kept verbatim
    r'|//[^\n]*'             # // line comment    -> removed
    r'|/\*.*?\*/',           # /* block comment */ (may span lines) -> one space
    re.DOTALL,
)


def remove_comments(source_code):
    """Strip C/C++ comments from *source_code*.

    Line comments are deleted up to (not including) the newline. Block
    comments are replaced by a single space, as the C preprocessor does, so
    that ``int/**/x`` does not fuse into ``intx``. Comment markers that appear
    inside string or character literals are left untouched. An unterminated
    ``/*`` is left as-is.

    Args:
        source_code: C/C++ source text.

    Returns:
        The source text without comments.
    """
    def _replace(match):
        token = match.group(0)
        if token.startswith('//'):
            return ''
        if token.startswith('/*'):
            return ' '
        return token  # string/char literal: not a comment

    return _COMMENT_OR_LITERAL.sub(_replace, source_code)
_DEFINE_DIRECTIVE = re.compile(
    r'^[ \t]*#[ \t]*define\b'  # directive start; "#  define" is legal C
    r'(?:[^\n]*\\\n)*'         # continuation lines ending in a backslash
    r'[^\n]*',                 # last (or only) line of the directive
    re.MULTILINE,
)


def process_macros(code):
    """Delete every ``#define`` directive from *code*.

    Multi-line macros continued with a trailing backslash are removed in
    full. The newline that ends the directive is kept, so surrounding lines
    (including blank ones) keep their relative position. Other preprocessor
    directives such as ``#include`` are left alone.

    Args:
        code: C/C++ source text.

    Returns:
        The source text without ``#define`` directives.
    """
    # [ \t]* rather than \s* so the match cannot reach back across newlines
    # and eat blank lines that precede the directive.
    return _DEFINE_DIRECTIVE.sub('', code)
_HORIZONTAL_WHITESPACE_RUN = re.compile(r'[\t ]+')


def normalize_whitespace(source_code):
    """Collapse each run of spaces/tabs in *source_code* into one space.

    Newlines are not touched, so the line structure is preserved.
    """
    return _HORIZONTAL_WHITESPACE_RUN.sub(' ', source_code)
# if ( <identifier> == true )  — whitespace anywhere is optional.
_IF_EQUALS_TRUE = re.compile(
    r'\bif\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*==\s*true\s*\)'
)


def simplify_conditions(source_code):
    """Rewrite ``if (name == true)`` as ``if (name)``.

    Only a bare identifier compared against the literal ``true`` is
    rewritten; every other condition is left unchanged.

    Args:
        source_code: C/C++ source text.

    Returns:
        The source text with such conditions simplified.
    """
    return _IF_EQUALS_TRUE.sub(r'if (\g<1>)', source_code)
def remove_redundant_parentheses(code):
    """Strip padding whitespace inside parentheses that hold a single word.

    ``( x )`` becomes ``(x)`` and ``f( a )`` becomes ``f(a)``. The
    parentheses themselves are kept. Text outside such single-word groups is
    not modified.

    Args:
        code: C/C++ source text.

    Returns:
        The normalized source text.
    """
    # \( and \) must be escaped: an unescaped group here would wrap every
    # space-delimited word of the file in parentheses.
    return re.sub(r'\(\s*(\w+)\s*\)', r'(\1)', code)
def extract_identifiers(code):
    """List every C-style identifier (letter/underscore, then word chars) in *code*.

    Identifiers are returned in order of appearance, duplicates included.
    Keywords are not filtered out.
    """
    word = re.compile(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\b')
    return [hit.group(1) for hit in word.finditer(code)]
# Standard header declaring each recognised libc routine.
_HEADER_FUNCTIONS = {
    '<stdio.h>': (
        'printf', 'fprintf', 'sprintf', 'snprintf', 'vprintf', 'vfprintf',
        'vsprintf', 'vsnprintf', 'scanf', 'fscanf', 'sscanf', 'gets', 'fgets',
        'puts', 'fputs', 'getchar', 'putchar', 'fopen', 'fclose', 'fread',
        'fwrite',
    ),
    '<stdlib.h>': (
        'malloc', 'calloc', 'realloc', 'free', 'exit', 'atoi', 'atol', 'strtol',
    ),
    '<string.h>': (
        'strcpy', 'strncpy', 'strcat', 'strncat', 'strlen', 'strcmp', 'strncmp',
        'memcpy', 'memmove', 'memset',
    ),
}
_HEADER_FOR_FUNCTION = {
    fn: header for header, fns in _HEADER_FUNCTIONS.items() for fn in fns
}


def guess_headers(identifiers):
    """Return the set of standard headers needed by the libc calls in *identifiers*.

    Matching is by exact name, so user identifiers such as ``blueprint`` or
    ``freedom`` no longer pull in unrelated headers.

    Args:
        identifiers: iterable of identifier strings (e.g. from extract_identifiers).

    Returns:
        A set of header names in angle-bracket form, e.g. ``{'<stdio.h>'}``.
    """
    return {
        _HEADER_FOR_FUNCTION[name]
        for name in identifiers
        if name in _HEADER_FOR_FUNCTION
    }
def complete_headers(code):
    """Prepend ``#include`` lines for standard headers *code* needs but lacks.

    The needed headers come from guess_headers(extract_identifiers(code)).
    Headers already included in *code* are skipped. The rest are emitted in
    sorted order, so the output is the same on every run; iterating the set
    directly would follow hash order. When nothing is missing, *code* is
    returned unchanged.

    Args:
        code: C/C++ source text.

    Returns:
        The source text, possibly with ``#include`` lines prepended.
    """
    needed = guess_headers(extract_identifiers(code))
    missing = sorted(
        header
        for header in needed
        if not re.search(
            r'^[ \t]*#[ \t]*include[ \t]*' + re.escape(header), code, re.MULTILINE
        )
    )
    if not missing:
        return code
    header_directives = [f'#include {header}' for header in missing]
    return '\n'.join(header_directives) + '\n' + code
# 主处理逻辑 (main analysis logic)
# libc routines the analyser looks for. Only key membership is consulted in
# this file; the values look like each routine's argument count
# (NOTE(review): confirm before relying on them).
dangerous_fun = dict(gets=1, strcpy=2, strcat=2, fgets=3, strncpy=3, strncat=3)
# Analysis state shared by the pipeline functions and reset/rebuilt per run.
program = dict()    # function name -> {"variables": [...], "instructions": [...]}
variables = dict()  # function name -> list of variable dicts
stack = list()      # (slot name, offset from rbp) pairs, built sorted by offset
vulns = list()      # vulnerability records appended by addVuln
# 特征提取 (feature extraction)
def parse_c_file(file_content):
    """Record array declarations and dangerous-call sites in ``program["main"]``.

    Each array declaration of type char/int/long/short/float/double is given a
    synthetic frame address ``rbp-<hex>``, 16 bytes apart in declaration order
    (``rbp-10``, ``rbp-20``, ...). Its declared element count is kept under
    ``"size"`` (None for ``name[]``). Every call to a routine listed in
    ``dangerous_fun`` becomes a ``call`` instruction with a synthetic code
    address starting at 0x400000 and growing by 0x20 per call. Everything
    found in the file is stored under the single ``"main"`` entry, which is
    reset on each call.

    Calls are detected anywhere on a line, including inside conditions such
    as ``while (fgets(buf, n, stdin))`` and with nested parentheses. As a
    consequence, a prototype that declares one of these names is recorded too.

    Args:
        file_content: preprocessed C source text.
    """
    program["main"] = {"variables": [], "instructions": []}
    variables_section = program["main"]["variables"]
    instructions_section = program["main"]["instructions"]
    variable_pattern = re.compile(
        r'(?:char|int|long|short|float|double)\s+(\w+)\s*\[\s*(\d+)?\s*\];?'
    )
    call_pattern = re.compile(r'\b(\w+)\s*\(')

    address_counter = 0x400000  # synthetic start address for recorded calls
    for line in file_content.split('\n'):
        # finditer, not search: a line may hold several declarations/calls.
        for var_match in variable_pattern.finditer(line):
            var_name, size_text = var_match.group(1), var_match.group(2)
            address_offset = len(variables_section) * 16 + 16
            var_addr = f"rbp-{address_offset:x}"  # lower-case hex
            variables_section.append({
                "name": var_name,
                "address": var_addr,
                "size": int(size_text) if size_text else None,
            })
            print(f"Variable parsed: {var_name} at {var_addr}")

        for call_match in call_pattern.finditer(line):
            fn_name = call_match.group(1)
            if fn_name not in dangerous_fun:
                continue
            instructions_section.append({
                "op": "call",
                "args": {
                    "fnname": f"<{fn_name}@plt>"
                },
                "address": f"0x{address_counter:x}"
            })
            print(f"Function call parsed: {fn_name} at address 0x{address_counter:x}")
            address_counter += 0x20
def load_vars():
    """Rebuild ``variables`` and ``stack`` from the parsed ``program``.

    ``variables[fn]`` receives the variable dicts of each function in
    ``program``. ``stack`` becomes a list of (name, offset) pairs sorted by
    offset. The saved rbp sits at 0, the return address at 8, and each
    variable at minus the hex number in its ``rbp-<hex>`` address.
    """
    global stack
    offsets = {"rbp": 0, "ret": 8}
    for fn_name, fn_info in program.items():
        fn_vars = variables[fn_name] = []
        for var in fn_info['variables']:
            fn_vars.append(var)
            hex_part = re.search(r'rbp-(.*)', var['address'])
            if hex_part is None:
                continue
            offsets[var['name']] = -int(hex_part.group(1), 16)
            print(f"Variable address loaded: {var['name']} -> {var['address']}")
    stack = sorted(offsets.items(), key=lambda pair: pair[1])
    print(f"Loaded stack: {stack}")
def getVarFromAddr(rbp_addr):
    """Return the first recorded variable whose address equals *rbp_addr*.

    The comparison ignores case, and every candidate is logged as it is
    checked. Returns None when no variable in ``variables`` matches.
    """
    every_var = (var for fn_vars in variables.values() for var in fn_vars)
    for candidate in every_var:
        print(f"Checking address {rbp_addr} against variable {candidate['name']} with address {candidate['address']}")
        if candidate['address'].lower() == rbp_addr.lower():
            return candidate
    return None
def getStartingPoint(name):
    """Return the offset paired with *name* in ``stack`` (first match), or None."""
    return next((entry[1] for entry in stack if entry[0] == name), None)
def updateStack(name, value):
    """Set the offset of every ``stack`` entry named *name* to *value*, then log the stack.

    Entries are replaced in place; the list is not re-sorted.
    """
    for idx, entry in enumerate(stack):
        if entry[0] != name:
            continue
        stack[idx] = (entry[0], value) + tuple(entry[2:])
    print(f"Updated stack: {stack}")
def simplify():
    """Flatten every ``call`` instruction in ``program`` into one list.

    Returns:
        ``[function, "<name@plt>", address, function, "<name@plt>", address, ...]``
        in program order.
    """
    flat = []
    for fn_name, body in program.items():
        calls = (ins for ins in body["instructions"] if ins["op"] == "call")
        for ins in calls:
            target, where = ins["args"]["fnname"], ins["address"]
            flat.extend((fn_name, target, where))
            print(f"Instruction simplified: {fn_name} {target} {where}")
    print(f"Simplified instruction set: {flat}")
    return flat
def _slot_offset(var_name):
    """Return *var_name*'s current offset in ``stack``, printing an error if it has none."""
    offset = getStartingPoint(var_name)
    if offset is None:
        print(f"Error: No stack slot found for variable {var_name}")
    return offset


def _report_reached_slots(dest_name, start, end, fun, fnname, addr, include_end=False):
    """Record one vulnerability per stack slot reached by a write into *dest_name*.

    The write covers offsets strictly above *start* and strictly below *end*,
    or up to and including *end* when *include_end* is true. ``end=None``
    means the write is unbounded. The destination itself is never reported.
    """
    for slot_name, slot_offset in stack:
        if slot_name == dest_name or slot_offset <= start:
            continue
        if end is not None and (slot_offset > end or (slot_offset == end and not include_end)):
            continue
        vuln_name = getVulnName(slot_offset)
        print(f"Adding vulnerability {vuln_name} for {dest_name} overflow to {slot_name}")
        addVuln(vuln_name, dest_name, slot_name, fun, addr, fnname)


def _finish_bounded_write(dest_name, start, end, fun, fnname, addr, include_end=False):
    """Report a write into *dest_name* ending at *end*, then move its stack entry to *end*."""
    _report_reached_slots(dest_name, start, end, fun, fnname, addr, include_end)
    if end >= 16:
        # Offsets >= 16 lie past the saved rbp (0) and return address (8):
        # report the frame corruption once, not once per slot.
        addVuln(getVulnName(end), dest_name, None, fun, addr, fnname)
    # Kept from the original heuristic: the destination's recorded offset
    # becomes the end of this write.
    updateStack(dest_name, end)


def process(simple):
    """Simulate each dangerous call in *simple* and record overflows in ``vulns``.

    *simple* is the flat list from simplify():
    ``[function, "<name@plt>", address, ...]``. Each stack slot a write
    reaches is reported once via addVuln. ``vuln_function`` is the function
    containing the call, and the routine name goes in ``fnname``.

    NOTE(review): buffers are not resolved from call arguments. Each routine
    uses fixed frame addresses: rbp-10 for gets/fgets, rbp-20/rbp-30 for
    strcpy/strcat/strncat, rbp-40/rbp-50 for strncpy. These are the addresses
    parse_c_file gives the 1st..5th declared arrays. The fgets length (100)
    and the strncpy length (20) are fixed as well.
    """
    for i in range(0, len(simple), 3):
        fun, fnname, addr = simple[i], simple[i + 1], simple[i + 2]

        if fnname == "<gets@plt>":
            dest_var_addr = 'rbp-10'
            dest_var = getVarFromAddr(dest_var_addr)
            print(f"Processing gets@plt, dest_var: {dest_var}")
            if not dest_var:
                print(f"Error: No variable found for address {dest_var_addr}")
                continue
            starting_point = _slot_offset(dest_var['name'])
            if starting_point is None:
                continue
            print(f"Starting point for {dest_var['name']}: {starting_point}")
            # gets() is unbounded: every slot above the buffer is reachable and
            # the write can always run past the frame.
            _report_reached_slots(dest_var['name'], starting_point, None, fun, fnname, addr)
            addVuln("SCORRUPTION", dest_var["name"], None, fun, addr, fnname)

        elif fnname == "<fgets@plt>":
            dest_var_addr = 'rbp-10'
            dest_var = getVarFromAddr(dest_var_addr)
            print(f"Processing fgets@plt, dest_var: {dest_var}, buffer size: 100")
            if not dest_var:
                print(f"Error: No variable found for address {dest_var_addr}")
                continue
            starting_point = _slot_offset(dest_var['name'])
            if starting_point is None:
                continue
            print(f"Starting point for {dest_var['name']}: {starting_point}")
            ending_point = starting_point + 100
            print(f"Ending point for fgets: {ending_point}")
            _finish_bounded_write(dest_var['name'], starting_point, ending_point, fun, fnname, addr)

        elif fnname in ("<strcpy@plt>", "<strcat@plt>", "<strncat@plt>"):
            dest_var_addr = 'rbp-20'
            source_var_addr = 'rbp-30'
            dest_var = getVarFromAddr(dest_var_addr)
            source_var = getVarFromAddr(source_var_addr)
            print(f"Processing {fnname}, dest_var: {dest_var}, source_var: {source_var}")
            if not dest_var or not source_var:
                print(f"Error: Missing variable for addresses dest: {dest_var_addr}, source: {source_var_addr}")
                continue
            starting_point = _slot_offset(dest_var['name'])
            init_size = _slot_offset(source_var['name'])
            if starting_point is None or init_size is None:
                continue
            print(f"Starting point for {dest_var['name']}: {starting_point}")
            # NOTE(review): the copy length is approximated by the magnitude of
            # the source's frame offset, not by its declared size.
            ending_point = starting_point + abs(init_size)
            print(f"Ending point for {fnname}: {ending_point}")
            _finish_bounded_write(dest_var['name'], starting_point, ending_point,
                                  fun, fnname, addr, include_end=True)

        elif fnname == "<strncpy@plt>":
            dest_var_addr = 'rbp-40'
            source_var_addr = 'rbp-50'
            n = 20
            dest_var = getVarFromAddr(dest_var_addr)
            source_var = getVarFromAddr(source_var_addr)
            print(f"Processing strncpy@plt, dest_var: {dest_var}, source_var: {source_var}")
            if not dest_var or not source_var:
                print(f"Error: Missing variable for addresses dest: {dest_var_addr}, source: {source_var_addr}")
                continue
            starting_point = _slot_offset(dest_var['name'])
            if starting_point is None:
                continue
            print(f"Starting point for {dest_var['name']}: {starting_point}")
            ending_point = starting_point + n
            print(f"Ending point for strncpy: {ending_point}")
            _finish_bounded_write(dest_var['name'], starting_point, ending_point, fun, fnname, addr)
def getVulnName(ending_point):
    """Classify a stack offset (relative to rbp) by what a write there hits.

    Returns 'SCORRUPTION' for offsets of 16 and above, 'RETOVERFLOW' for
    8..15, 'RBPOVERFLOW' for 0..7 and 'VAROVERFLOW' for negative offsets.
    Returns None for values that compare false with everything (e.g. NaN).
    """
    if ending_point >= 16:
        return 'SCORRUPTION'
    if ending_point >= 8:
        return 'RETOVERFLOW'
    if ending_point >= 0:
        return 'RBPOVERFLOW'
    if ending_point < 0:
        return 'VAROVERFLOW'
    return None
def addVuln(vulnerability, overflow_var, overflown_var, vuln_function, address, fnname):
    """Build a vulnerability record, log it, and append it to ``vulns``.

    Args:
        vulnerability: 'VAROVERFLOW', 'SCORRUPTION', 'RBPOVERFLOW' or
            'RETOVERFLOW'; any other value gets no description.
        overflow_var: name of the buffer being written.
        overflown_var: overwritten slot; stored only for 'VAROVERFLOW'.
        vuln_function: stored verbatim under 'vuln_function'.
        address: call-site address string.
        fnname: PLT-style name such as '<gets@plt>'; stored with the leading
            '<' and trailing '@plt>' sliced off.
    """
    per_kind = {
        'VAROVERFLOW': {
            'description': 'Overflow into adjacent variable.',
            'overflown_var': overflown_var,
        },
        'SCORRUPTION': {
            'description': "Stack corruption beyond saved frame for main.",
            'overflown_address': "rbp+0x10",
        },
        'RBPOVERFLOW': {'description': "Overflow into saved base pointer (rbp)."},
        'RETOVERFLOW': {'description': "Overflow into return address."},
    }
    record = {
        'vulnerability': vulnerability,
        'overflow_var': overflow_var,
        'vuln_function': vuln_function,
        'address': address,
        'fnname': fnname[1:-5],
    }
    record.update(per_kind.get(vulnerability, {}))
    print(f"Vulnerability added: {record}")
    vulns.append(record)
def analyze_file(filepath):
    """Run the whole pipeline on the C/C++ file at *filepath*.

    Steps: read the file, preprocess it (comments, #define, whitespace,
    conditions, parentheses, headers), parse declarations and dangerous calls,
    build the stack model, then simulate the calls. ``vulns`` is reset first.

    Args:
        filepath: path of the source file to analyse.

    Returns:
        The list of vulnerability dicts collected in ``vulns``.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    global vulns
    vulns = []
    # errors="replace": source files are often not UTF-8 (e.g. GBK comments).
    # Undecodable bytes become U+FFFD instead of aborting this heuristic scan.
    with open(filepath, 'r', encoding='utf-8', errors='replace') as infile:
        file_content = infile.read()

    # Preprocessing, applied in this order.
    for step in (
        remove_comments,
        process_macros,
        normalize_whitespace,
        simplify_conditions,
        remove_redundant_parentheses,
        complete_headers,
    ):
        file_content = step(file_content)
    print(f"Processed code:\n{file_content}\n")

    parse_c_file(file_content)
    print(f"Parsed program: {program}")
    load_vars()
    simple = simplify()
    print(f"Simplified instructions: {simple}")
    process(simple)
    print(f"Detected vulnerabilities: {vulns}")
    return vulns
class BufferOverflowAnalyzerGUI(tk.Tk):
    """Main window: pick a C/C++ file, run analyze_file on it, list the findings."""

    def __init__(self):
        # Must be __init__ (not "init"): otherwise Tk never calls it and the
        # widgets and attributes below are never created.
        super().__init__()
        self.title("缓冲区溢出漏洞检测")
        self.geometry("800x500")

        title_label = tk.Label(self, text="缓冲区溢出漏洞检测", font=('宋体', '20',))
        title_label.pack()
        hint_label = tk.Label(self, text="请选择待检测的代码文件", font=('宋体', '13',))
        hint_label.pack(pady=10)

        file_frame = tk.Frame(self)
        file_frame.pack(pady=10)
        self.select_file_button = tk.Button(file_frame, text="选择文件", command=self.select_file)
        self.select_file_button.pack(side=tk.LEFT)
        self.file_path_label = tk.Label(file_frame, text="", wraplength=500)
        self.file_path_label.pack(side=tk.LEFT, padx=10)
        self.start_analysis_button = tk.Button(file_frame, text="开始检测", command=self.analyze_selected_file)
        self.start_analysis_button.pack(side=tk.LEFT, padx=10)

        # Text box and scrollbar sit side by side in their own frame instead
        # of packing the scrollbar inside the Text widget.
        result_frame = tk.Frame(self)
        result_frame.pack(padx=20, pady=20, expand=True, fill='both')
        self.scrollbar = tk.Scrollbar(result_frame)
        self.scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        self.result_text = tk.Text(result_frame, height=20, wrap='word',
                                   yscrollcommand=self.scrollbar.set)
        self.result_text.pack(side=tk.LEFT, expand=True, fill='both')
        self.scrollbar.config(command=self.result_text.yview)

        self.selected_file = None  # path chosen via select_file, or None

    def select_file(self):
        """Ask for a .c/.cpp file and remember and show its path if one was chosen."""
        file_path = filedialog.askopenfilename(filetypes=[("C and C++ Files", "*.c *.cpp")])
        if file_path:
            self.selected_file = file_path
            self.file_path_label.config(text=file_path)

    def analyze_selected_file(self):
        """Analyse the selected file and show the results, or an error dialog."""
        if not self.selected_file:
            messagebox.showerror("错误", "没有选择文件!")
            return
        try:
            results = analyze_file(self.selected_file)
        except (OSError, ValueError) as exc:
            # Exceptions escaping a Tk callback are only printed to stderr;
            # show them to the user instead.
            messagebox.showerror("错误", f"检测失败: {exc}")
            return
        print(f"Analysis results: {results}")
        self.display_results(results)

    def display_results(self, results):
        """Replace the text box contents with one formatted entry per vulnerability."""
        self.result_text.delete("1.0", tk.END)
        if not results:
            self.result_text.insert(tk.END, "未找到缓冲区溢出漏洞.\n")
            return
        for vuln in results:
            self.result_text.insert(tk.END, f"漏洞类型: {vuln['vulnerability']}\n")
            self.result_text.insert(tk.END, f"产生漏洞的函数: {vuln['vuln_function']}\n")
            self.result_text.insert(tk.END, f"溢出变量: {vuln.get('overflow_var', 'N/A')}\n")
            self.result_text.insert(tk.END, f"被覆盖的变量/超出地址: {vuln.get('overflown_var', vuln.get('overflown_address', 'N/A'))}\n")
            self.result_text.insert(tk.END, f"地址: {vuln['address']}\n")
            self.result_text.insert(tk.END, f"危险函数: {vuln['fnname']}\n")
            self.result_text.insert(tk.END, f"描述: {vuln.get('description', 'No description available')}\n")
            self.result_text.insert(tk.END, "-"*40 + "\n")
# Launch the GUI only when run as a script, not when imported.
if __name__ == "__main__":
    app = BufferOverflowAnalyzerGUI()
    app.mainloop()
`