Add long integers
jump to
@@ -4,7 +4,8 @@ An attempt to write an (educational) C(17) compiler. Provides a Lexer, Parser and a Code Generator (x86\_64).
The compiler driver uses GCC's preprocessor and assembler/linker, to help turning the code into an executable. -Supports: +### Supports: + - Return - Unary operators (!, ~, -) - Binary arithmetic and logical operators (+, -, \*, /, %, <<, >>, &, |)@@ -15,16 +16,5 @@ - If Statements (and goto + labeled statements)
- Compound Statements - Loops and Switch Statements - Function calls -- ~~File scope~~ - -Here is an example of now compilable c code: -``` -int main(void) { - int a = 2 >> 1; - int b = 1; - b += 1; - int c = a++; - int d = --b; - return (a == 2 && b == 1 && c == 1 && d == 1); // returns 1 -} -``` +- File scope +- Long type
@@ -36,17 +36,29 @@ write!(f, " {}\n", i)?;
} Ok(()) } - TopLevel::StaticVariable(name, global, init) => { - if *init == 0 { - write_global(f, name, *global)?; - write!(f, " .bss\n")?; - write_alignment(f)?; - write!(f, "{}:\n .zero 4\n", name) - } else { - write_global(f, name, *global)?; - write!(f, " .data\n")?; - write_alignment(f)?; - write!(f, "{}:\n .long {}\n", name, init) + TopLevel::StaticVariable(name, global, alignment, init) => { + write_global(f, name, *global)?; + match init { + crate::frontend::type_check::StaticInit::IntInit(0) => { + write!(f, " .bss\n")?; + write!(f, " .align {}\n", alignment)?; + write!(f, "{}:\n .zero 4\n", name) + } + crate::frontend::type_check::StaticInit::LongInit(0) => { + write!(f, " .bss\n")?; + write!(f, " .align {}\n", alignment)?; + write!(f, "{}:\n .zero 8\n", name) + } + crate::frontend::type_check::StaticInit::IntInit(x) => { + write!(f, " .data\n")?; + write!(f, " .align {}\n", alignment)?; + write!(f, "{}:\n .long {}\n", name, x) + } + crate::frontend::type_check::StaticInit::LongInit(x) => { + write!(f, " .data\n")?; + write!(f, " .align {}\n", alignment)?; + write!(f, "{}:\n .quad {}\n", name, x) + } } } }@@ -56,20 +68,60 @@
impl fmt::Display for Instruction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Instruction::Mov(src, dst) => write!(f, "movl {}, {}", src, dst), + Instruction::Mov(t, src, dst) => write!( + f, + "mov{} {}, {}", + t, + get_typed_operand(t, src), + get_typed_operand(t, dst) + ), + Instruction::Movsx(src, dst) => write!( + f, + "movslq {}, {}", + get_lword_operand(src), + get_qword_operand(dst) + ), Instruction::Ret => write!(f, "movq %rbp, %rsp\n popq %rbp\n ret"), - Instruction::Unary(un_op, operand) => write!(f, "{} {}", un_op, operand), - Instruction::AllocStack(x) => write!(f, "subq ${}, %rsp", x), - Instruction::DeallocStack(x) => write!(f, "addq ${}, %rsp", x), - Instruction::Binary(bin_op @ (BinOp::LShift | BinOp::RShift), operand1, operand2) => { - write!(f, "{} {}, {}", bin_op, get_byte_operand(operand1), operand2) + Instruction::Unary(un_op, t, operand) => { + write!(f, "{}{} {}", un_op, t, get_typed_operand(t, operand)) } - Instruction::Binary(bin_op, operand1, operand2) => { - write!(f, "{} {}, {}", bin_op, operand1, operand2) + Instruction::Binary( + bin_op @ (BinOp::LShift | BinOp::RShift), + t, + operand1, + operand2, + ) => { + write!( + f, + "{}{} {}, {}", + bin_op, + t, + get_byte_operand(operand1), + get_typed_operand(t, operand2) + ) } - Instruction::Idiv(operand) => write!(f, "idivl {}", operand), - Instruction::Cdq => write!(f, "cdq"), - Instruction::Cmp(op1, op2) => write!(f, "cmpl {}, {}", op1, op2), + Instruction::Binary(bin_op, t, operand1, operand2) => { + write!( + f, + "{}{} {}, {}", + bin_op, + t, + get_typed_operand(t, operand1), + get_typed_operand(t, operand2) + ) + } + Instruction::Idiv(t, operand) => { + write!(f, "idiv{} {}", t, get_typed_operand(t, operand)) + } + Instruction::Cdq(AssemblyType::Longword) => write!(f, "cdq"), + Instruction::Cdq(AssemblyType::Quadword) => write!(f, "cqo"), + Instruction::Cmp(t, op1, op2) => write!( + f, + "cmp{} {}, {}", + t, + get_typed_operand(t, op1), + get_typed_operand(t, op2) + ), Instruction::Jump(label) => write!(f, "jmp .L{}", label), Instruction::JmpCC(condition, label) => write!(f, "j{} .L{}", condition, label), Instruction::SetCC(condition, operand) => {@@ -78,6 +130,15 @@ }
Instruction::Label(label) => write!(f, ".L{}:", label), Instruction::Push(operand) => write!(f, "pushq {}", get_qword_operand(operand)), Instruction::Call(func) => write!(f, "call {}@PLT", func), + } + } +} + +impl fmt::Display for AssemblyType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AssemblyType::Longword => write!(f, "l"), + AssemblyType::Quadword => write!(f, "q"), } } }@@ -86,10 +147,10 @@ impl fmt::Display for Operand {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Operand::Immediate(x) => write!(f, "${}", x), - Operand::Register(reg) => write!(f, "%{}", reg), Operand::Pseudo(x) => write!(f, "PSEUDO {}", x), Operand::Stack(x) => write!(f, "{}(%rbp)", x), Operand::Data(x) => write!(f, "{}(%rip)", x), + Operand::Register(_) => Ok(()), } } }@@ -102,6 +163,7 @@ Reg::RCX => "%cl".to_string(),
Reg::RDX => "%dl".to_string(), Reg::RDI => "%dil".to_string(), Reg::RSI => "%sil".to_string(), + Reg::RSP => "%spl".to_string(), Reg::R8 => "%r8b".to_string(), Reg::R9 => "%r9b".to_string(), Reg::R10 => "%r10b".to_string(),@@ -119,6 +181,7 @@ Reg::RCX => "%rcx".to_string(),
Reg::RDX => "%rdx".to_string(), Reg::RDI => "%rdi".to_string(), Reg::RSI => "%rsi".to_string(), + Reg::RSP => "%rsp".to_string(), Reg::R8 => "%r8".to_string(), Reg::R9 => "%r9".to_string(), Reg::R10 => "%r10".to_string(),@@ -128,29 +191,38 @@ x => x.to_string(),
} } -impl fmt::Display for Reg { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Reg::RAX => write!(f, "eax"), - Reg::RCX => write!(f, "ecx"), - Reg::RDX => write!(f, "edx"), - Reg::RDI => write!(f, "edi"), - Reg::RSI => write!(f, "esi"), - Reg::R8 => write!(f, "r8d"), - Reg::R9 => write!(f, "r9d"), - Reg::R10 => write!(f, "r10d"), - Reg::R11 => write!(f, "r11d"), - } +fn get_lword_operand(operand: &Operand) -> String { + match operand { + Operand::Register(reg) => match reg { + Reg::RAX => "%eax".to_string(), + Reg::RCX => "%ecx".to_string(), + Reg::RDX => "%edx".to_string(), + Reg::RDI => "%edi".to_string(), + Reg::RSI => "%esi".to_string(), + Reg::RSP => "%esp".to_string(), + Reg::R8 => "%r8d".to_string(), + Reg::R9 => "%r9d".to_string(), + Reg::R10 => "%r10d".to_string(), + Reg::R11 => "%r11d".to_string(), + }, + x => x.to_string(), + } +} + +fn get_typed_operand(asm_type: &AssemblyType, operand: &Operand) -> String { + match asm_type { + AssemblyType::Longword => get_lword_operand(operand), + AssemblyType::Quadword => get_qword_operand(operand), } } impl fmt::Display for UnOp { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - UnOp::Neg => write!(f, "negl"), - UnOp::Not => write!(f, "notl"), - UnOp::Inc => write!(f, "incl"), - UnOp::Dec => write!(f, "decl"), + UnOp::Neg => write!(f, "neg"), + UnOp::Not => write!(f, "not"), + UnOp::Inc => write!(f, "inc"), + UnOp::Dec => write!(f, "dec"), } } }@@ -158,14 +230,14 @@
impl fmt::Display for BinOp { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - BinOp::Add => write!(f, "addl"), - BinOp::Sub => write!(f, "subl"), - BinOp::Mul => write!(f, "imull"), - BinOp::And => write!(f, "andl"), - BinOp::Or => write!(f, "orl"), - BinOp::Xor => write!(f, "xorl"), - BinOp::LShift => write!(f, "sall"), - BinOp::RShift => write!(f, "sarl"), + BinOp::Add => write!(f, "add"), + BinOp::Sub => write!(f, "sub"), + BinOp::Mul => write!(f, "imul"), + BinOp::And => write!(f, "and"), + BinOp::Or => write!(f, "or"), + BinOp::Xor => write!(f, "xor"), + BinOp::LShift => write!(f, "sal"), + BinOp::RShift => write!(f, "sar"), } } }
@@ -1,5 +1,6 @@
+use crate::frontend::ast::Type; use crate::frontend::ir; -use crate::frontend::type_check::{IdentifierAttributes, Type}; +use crate::frontend::type_check::{IdentifierAttributes, StaticInit}; use anyhow::{Result, bail}; use std::cmp::min; use std::collections::HashMap;@@ -12,23 +13,22 @@
#[derive(Debug, PartialEq)] pub enum TopLevel { Function(String, bool, Vec<Instruction>), - StaticVariable(String, bool, i32), + StaticVariable(String, bool, i32, StaticInit), } #[derive(Debug, PartialEq)] pub enum Instruction { - Mov(Operand, Operand), - Unary(UnOp, Operand), - Binary(BinOp, Operand, Operand), - Cmp(Operand, Operand), - Idiv(Operand), - Cdq, + Mov(AssemblyType, Operand, Operand), + Movsx(Operand, Operand), + Unary(UnOp, AssemblyType, Operand), + Binary(BinOp, AssemblyType, Operand, Operand), + Cmp(AssemblyType, Operand, Operand), + Idiv(AssemblyType, Operand), + Cdq(AssemblyType), Jump(String), JmpCC(Condition, String), SetCC(Condition, Operand), Label(String), - AllocStack(i32), - DeallocStack(i32), Push(Operand), Call(String), Ret,@@ -36,20 +36,27 @@ }
#[derive(Debug, PartialEq, Clone)] pub enum Operand { - Immediate(i32), + Immediate(i64), Register(Reg), Pseudo(String), - Stack(i32), + Stack(i64), Data(String), } #[derive(Debug, PartialEq, Clone)] +pub enum AssemblyType { + Longword, + Quadword, +} + +#[derive(Debug, PartialEq, Clone)] pub enum Reg { RAX, RCX, RDX, RDI, RSI, + RSP, R8, R9, R10,@@ -86,9 +93,47 @@ LShift,
RShift, } +fn operand_to_asm_type( + operand: &ir::Operand, + symbol_table: &HashMap<String, (Type, IdentifierAttributes)>, +) -> Result<AssemblyType> { + match operand { + ir::Operand::Constant(c) => match c { + crate::frontend::ast::Const::Int(_) => Ok(AssemblyType::Longword), + crate::frontend::ast::Const::Long(_) => Ok(AssemblyType::Quadword), + }, + ir::Operand::Variable(var) => match symbol_table.get(var) { + Some(x) => match x.0 { + Type::Int => Ok(AssemblyType::Longword), + Type::Long => Ok(AssemblyType::Quadword), + Type::Function(_, _) => bail!("{} is no variable", var), + }, + None => bail!("Variable {} not found in symbol table", var), + }, + } +} + +// TODO: refactor symb rtable util function +fn get_asm_type_from_symbol_table( + name: &String, + symbol_table: &HashMap<String, (Type, IdentifierAttributes)>, +) -> Result<AssemblyType> { + match symbol_table.get(name) { + Some((var_type, _)) => match var_type { + Type::Int => Ok(AssemblyType::Longword), + Type::Long => Ok(AssemblyType::Quadword), + Type::Function(_, _) => bail!("{} is no variable", name), + }, + None => bail!("Could not find {} in symbol table", name), + } +} + fn parse_operand(expr: ir::Operand) -> Result<Operand> { match expr { - ir::Operand::Constant(c) => Ok(Operand::Immediate(c)), + ir::Operand::Constant(c) => match c { + crate::frontend::ast::Const::Int(i) => Ok(Operand::Immediate(i as i64)), + crate::frontend::ast::Const::Long(l) => Ok(Operand::Immediate(l)), + }, ir::Operand::Variable(var) => Ok(Operand::Pseudo(var)), } }@@ -134,12 +179,18 @@ name: String,
args: Vec<ir::Operand>, ret: ir::Operand, instructions: &mut Vec<Instruction>, + symbol_table: &HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<()> { let reg_args_mapping = vec![Reg::RDI, Reg::RSI, Reg::RDX, Reg::RCX, Reg::R8, Reg::R9]; let reg_args = &args[..min(6, args.len())]; let stack_args = &args[min(6, args.len())..]; let stack_padding = if stack_args.len() % 2 == 1 { - instructions.push(Instruction::AllocStack(8)); + instructions.push(Instruction::Binary( + BinOp::Sub, + AssemblyType::Quadword, + Operand::Immediate(8), + Operand::Register(Reg::RSP), + )); 8 } else { 0@@ -149,75 +200,120 @@ let mut index = 0;
for arg in reg_args { let reg = reg_args_mapping[index].clone(); let asm_arg = parse_operand(arg.clone())?; - instructions.push(Instruction::Mov(asm_arg, Operand::Register(reg))); + instructions.push(Instruction::Mov( + operand_to_asm_type(arg, symbol_table)?, + asm_arg, + Operand::Register(reg), + )); index += 1; } for arg in stack_args.into_iter().rev() { let asm_arg = parse_operand(arg.clone())?; + let asm_type = operand_to_asm_type(arg, symbol_table)?; match asm_arg { Operand::Immediate(_) | Operand::Register(_) => { instructions.push(Instruction::Push(asm_arg)) } _ => { - instructions.push(Instruction::Mov(asm_arg, Operand::Register(Reg::RAX))); - instructions.push(Instruction::Push(Operand::Register(Reg::RAX))) + if asm_type == AssemblyType::Quadword { + instructions.push(Instruction::Push(asm_arg)) + } else { + instructions.push(Instruction::Mov( + AssemblyType::Longword, + asm_arg, + Operand::Register(Reg::RAX), + )); + instructions.push(Instruction::Push(Operand::Register(Reg::RAX))) + } } } } instructions.push(Instruction::Call(name)); - instructions.push(Instruction::DeallocStack( - (8 * stack_args.len() + stack_padding) as i32, - )); + let rsp_restore_bytes = (8 * stack_args.len() + stack_padding) as i64; + if rsp_restore_bytes != 0 { + instructions.push(Instruction::Binary( + BinOp::Add, + AssemblyType::Quadword, + Operand::Immediate(rsp_restore_bytes), + Operand::Register(Reg::RSP), + )); + } + let ret_type = operand_to_asm_type(&ret, symbol_table)?; let ret = parse_operand(ret)?; - instructions.push(Instruction::Mov(Operand::Register(Reg::RAX), ret)); + instructions.push(Instruction::Mov(ret_type, Operand::Register(Reg::RAX), ret)); Ok(()) } -fn parse_instructions(instructions: Vec<ir::Instruction>) -> Result<Vec<Instruction>> { - let mut result = Vec::new(); +fn parse_instructions( + instructions: Vec<ir::Instruction>, + symbol_table: &HashMap<String, (Type, IdentifierAttributes)>, +) -> Result<Vec<Instruction>> { + let mut result: Vec<Instruction> = Vec::new(); for instr in instructions { match instr { ir::Instruction::Unary(ir::UnOp::Not, src, dst) => { + let asm_type = operand_to_asm_type(&src, symbol_table)?; let dst = parse_operand(dst)?; - result.push(Instruction::Cmp(Operand::Immediate(0), parse_operand(src)?)); - result.push(Instruction::Mov(Operand::Immediate(0), dst.clone())); + result.push(Instruction::Cmp( + asm_type.clone(), + Operand::Immediate(0), + parse_operand(src)?, + )); + result.push(Instruction::Mov( + asm_type, + Operand::Immediate(0), + dst.clone(), + )); result.push(Instruction::SetCC(Condition::E, dst)); } ir::Instruction::Unary(un_op, src, dst) => { + let asm_type = operand_to_asm_type(&src, symbol_table)?; let dst = parse_operand(dst)?; - result.push(Instruction::Mov(parse_operand(src)?, dst.clone())); - result.push(Instruction::Unary(parse_unary(un_op)?, dst)); + result.push(Instruction::Mov( + asm_type.clone(), + parse_operand(src)?, + dst.clone(), + )); + result.push(Instruction::Unary(parse_unary(un_op)?, asm_type, dst)); } - ir::Instruction::Ret(value) => { + ir::Instruction::Ret(op) => { + let asm_type = operand_to_asm_type(&op, symbol_table)?; result.push(Instruction::Mov( - parse_operand(value)?, + asm_type, + parse_operand(op)?, Operand::Register(Reg::RAX), )); result.push(Instruction::Ret); } ir::Instruction::Binary(ir::BinOp::Division, src1, src2, dst) => { + let asm_type = operand_to_asm_type(&src1, symbol_table)?; result.push(Instruction::Mov( + asm_type.clone(), parse_operand(src1)?, Operand::Register(Reg::RAX), )); - result.push(Instruction::Cdq); - result.push(Instruction::Idiv(parse_operand(src2)?)); + result.push(Instruction::Cdq(asm_type.clone())); + result.push(Instruction::Idiv(asm_type.clone(), parse_operand(src2)?)); result.push(Instruction::Mov( + asm_type, Operand::Register(Reg::RAX), parse_operand(dst)?, )); } ir::Instruction::Binary(ir::BinOp::Modulo, src1, src2, dst) => { + let asm_type = operand_to_asm_type(&src1, symbol_table)?; result.push(Instruction::Mov( + asm_type.clone(), parse_operand(src1)?, Operand::Register(Reg::RAX), )); - result.push(Instruction::Cdq); - result.push(Instruction::Idiv(parse_operand(src2)?)); + result.push(Instruction::Cdq(asm_type.clone())); + result.push(Instruction::Idiv(asm_type.clone(), parse_operand(src2)?)); result.push(Instruction::Mov( + asm_type, Operand::Register(Reg::RDX), parse_operand(dst)?, ));@@ -228,14 +324,21 @@ src1,
src2, dst, ) => { + let asm_type = operand_to_asm_type(&src1, symbol_table)?; let dst = parse_operand(dst)?; - result.push(Instruction::Mov(parse_operand(src1)?, dst.clone())); result.push(Instruction::Mov( + asm_type.clone(), + parse_operand(src1)?, + dst.clone(), + )); + result.push(Instruction::Mov( + asm_type.clone(), parse_operand(src2)?, Operand::Register(Reg::RCX), )); result.push(Instruction::Binary( parse_binary(bin_op)?, + asm_type, Operand::Register(Reg::RCX), dst, ));@@ -251,33 +354,57 @@ src1,
src2, dst, ) => { + let asm_type = operand_to_asm_type(&src1, symbol_table)?; let dst = parse_operand(dst)?; - result.push(Instruction::Cmp(parse_operand(src2)?, parse_operand(src1)?)); - result.push(Instruction::Mov(Operand::Immediate(0), dst.clone())); + result.push(Instruction::Cmp( + asm_type.clone(), + parse_operand(src2)?, + parse_operand(src1)?, + )); + result.push(Instruction::Mov( + asm_type.clone(), + Operand::Immediate(0), + dst.clone(), + )); result.push(Instruction::SetCC(parse_condition(bin_op)?, dst)); } ir::Instruction::Binary(bin_op, src1, src2, dst) => { + let asm_type = operand_to_asm_type(&src1, symbol_table)?; let dst = parse_operand(dst)?; - result.push(Instruction::Mov(parse_operand(src1)?, dst.clone())); + result.push(Instruction::Mov( + asm_type.clone(), + parse_operand(src1)?, + dst.clone(), + )); result.push(Instruction::Binary( parse_binary(bin_op)?, + asm_type, parse_operand(src2)?, dst, )); } ir::Instruction::Copy(src, dst) => { - result.push(Instruction::Mov(parse_operand(src)?, parse_operand(dst)?)) + let asm_type = operand_to_asm_type(&src, symbol_table)?; + result.push(Instruction::Mov( + asm_type, + parse_operand(src)?, + parse_operand(dst)?, + )) } ir::Instruction::Jump(label) => result.push(Instruction::Jump(label)), ir::Instruction::JumpIfZero(operand, label) => { + let asm_type = operand_to_asm_type(&operand, symbol_table)?; result.push(Instruction::Cmp( + asm_type.clone(), Operand::Immediate(0), parse_operand(operand)?, )); result.push(Instruction::JmpCC(Condition::E, label)); } ir::Instruction::JumpIfNotZero(operand, label) => { + let asm_type = operand_to_asm_type(&operand, symbol_table)?; result.push(Instruction::Cmp( + asm_type.clone(), Operand::Immediate(0), parse_operand(operand)?, ));@@ -285,72 +412,96 @@ result.push(Instruction::JmpCC(Condition::NE, label));
} ir::Instruction::Label(label) => result.push(Instruction::Label(label)), ir::Instruction::FunctionCall(name, params, ret) => { - parse_function_call(name, params, ret, &mut result)? + parse_function_call(name, params, ret, &mut result, symbol_table)? + } + ir::Instruction::SignExtend(src, dst) => { + result.push(Instruction::Movsx(parse_operand(src)?, parse_operand(dst)?)); + } + ir::Instruction::Truncate(src, dst) => { + result.push(Instruction::Mov( + AssemblyType::Longword, + parse_operand(src)?, + parse_operand(dst)?, + )); } } } Ok(result) } -fn parse_function(fun: ir::TopLevel) -> Result<TopLevel> { +fn parse_function( + fun: ir::TopLevel, + symbol_table: &HashMap<String, (Type, IdentifierAttributes)>, +) -> Result<TopLevel> { match fun { ir::TopLevel::Function(name, global, params, body) => { let reg_args_mapping = vec![Reg::RDI, Reg::RSI, Reg::RDX, Reg::RCX, Reg::R8, Reg::R9]; let mut instructions: Vec<Instruction> = Vec::new(); let mut index = 0; for param in params { + let asm_type = get_asm_type_from_symbol_table(¶m, symbol_table)?; if index == reg_args_mapping.len() { index = 16; } if index < reg_args_mapping.len() { instructions.push(Instruction::Mov( + asm_type, Operand::Register(reg_args_mapping[index].clone()), Operand::Pseudo(param), )); index += 1; } else if index > reg_args_mapping.len() { instructions.push(Instruction::Mov( - Operand::Stack(index as i32), + asm_type, + Operand::Stack(index as i64), Operand::Pseudo(param), )); index += 8; } } - instructions.append(&mut parse_instructions(body)?); + instructions.append(&mut parse_instructions(body, symbol_table)?); Ok(TopLevel::Function(name, global, instructions)) } - ir::TopLevel::StaticVariable(name, global, init) => { - Ok(TopLevel::StaticVariable(name, global, init)) - } + ir::TopLevel::StaticVariable(name, global, var_type, init) => match var_type { + Type::Int => Ok(TopLevel::StaticVariable(name, global, 4, init)), + Type::Long => Ok(TopLevel::StaticVariable(name, global, 8, init)), + Type::Function(_, _) => bail!("Found toplevel function"), + }, } } fn replace_pseudo_operand( operand: Operand, - hash_map: &mut HashMap<String, i32>, + hash_map: &mut HashMap<String, i64>, symbol_table: &HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<Operand> { match operand { Operand::Pseudo(key) => match hash_map.get(&key) { - Some(x) => Ok(Operand::Stack(*x)), + Some(x) => Ok(Operand::Stack(*x as i64)), None => { - if let Some((_, symb_attr)) = symbol_table.get(&key) { + let mut size_on_stack: i64 = 0; + if let Some((var_type, symb_attr)) = symbol_table.get(&key) { match symb_attr { IdentifierAttributes::StaticAttributes(_, _) => { return Ok(Operand::Data(key)); } _ => (), }; + match var_type { + Type::Int => size_on_stack = 4, + Type::Long => size_on_stack = 8, + Type::Function(_, _) => bail!("Function not allowed here"), + }; } let n_val = match hash_map.iter().min_by_key(|entry| entry.1) { - Some(x) => x.1 - 4, - None => -4, + Some(x) => (x.1 - size_on_stack) - (8 + ((x.1 - size_on_stack) % 8)), + None => -8, }; hash_map.insert(key, n_val); - Ok(Operand::Stack(n_val)) + Ok(Operand::Stack(n_val as i64)) } }, _ => Ok(operand),@@ -366,47 +517,55 @@ match asm {
Asm::Program(functions) => { for function in functions { let mut instr_pseudoless: Vec<Instruction> = Vec::new(); - let mut hash_map: HashMap<String, i32> = HashMap::new(); + let mut hash_map: HashMap<String, i64> = HashMap::new(); match function { TopLevel::Function(name, global, instructions) => { for instr in instructions { match instr { - Instruction::Mov(op1, op2) => { + Instruction::Mov(asm_type, op1, op2) => { + let n_op1 = + replace_pseudo_operand(op1, &mut hash_map, symbol_table)?; + let n_op2 = + replace_pseudo_operand(op2, &mut hash_map, symbol_table)?; + instr_pseudoless.push(Instruction::Mov(asm_type, n_op1, n_op2)); + } + Instruction::Movsx(op1, op2) => { let n_op1 = replace_pseudo_operand(op1, &mut hash_map, symbol_table)?; let n_op2 = replace_pseudo_operand(op2, &mut hash_map, symbol_table)?; - instr_pseudoless.push(Instruction::Mov(n_op1, n_op2)); + instr_pseudoless.push(Instruction::Movsx(n_op1, n_op2)); } - Instruction::Cmp(op1, op2) => { + Instruction::Cmp(asm_type, op1, op2) => { let n_op1 = replace_pseudo_operand(op1, &mut hash_map, symbol_table)?; let n_op2 = replace_pseudo_operand(op2, &mut hash_map, symbol_table)?; - instr_pseudoless.push(Instruction::Cmp(n_op1, n_op2)); + instr_pseudoless.push(Instruction::Cmp(asm_type, n_op1, n_op2)); } - Instruction::Unary(un_op, op) => { + Instruction::Unary(asm_type, un_op, op) => { let n_op = replace_pseudo_operand(op, &mut hash_map, symbol_table)?; - instr_pseudoless.push(Instruction::Unary(un_op, n_op)); + instr_pseudoless + .push(Instruction::Unary(asm_type, un_op, n_op)); } Instruction::SetCC(cond, op) => { let n_op = replace_pseudo_operand(op, &mut hash_map, symbol_table)?; instr_pseudoless.push(Instruction::SetCC(cond, n_op)); } - Instruction::Binary(bin_op, op1, op2) => { + Instruction::Binary(asm_type, bin_op, op1, op2) => { let n_op1 = replace_pseudo_operand(op1, &mut hash_map, symbol_table)?; let n_op2 = replace_pseudo_operand(op2, &mut hash_map, symbol_table)?; instr_pseudoless - .push(Instruction::Binary(bin_op, n_op1, n_op2)); + .push(Instruction::Binary(asm_type, bin_op, n_op1, n_op2)); } - Instruction::Idiv(op) => { + Instruction::Idiv(asm_type, op) => { let n_op = replace_pseudo_operand(op, &mut hash_map, symbol_table)?; - instr_pseudoless.push(Instruction::Idiv(n_op)); + instr_pseudoless.push(Instruction::Idiv(asm_type, n_op)); } Instruction::Push(op) => { let n_op =@@ -420,7 +579,15 @@
match hash_map.iter().min_by_key(|entry| entry.1) { Some(x) => { let padded = (-*x.1 + 15) / 16 * 16; - instr_pseudoless.insert(0, Instruction::AllocStack(padded)); + instr_pseudoless.insert( + 0, + Instruction::Binary( + BinOp::Sub, + AssemblyType::Quadword, + Operand::Immediate(padded), + Operand::Register(Reg::RSP), + ), + ); } None => (), };@@ -441,6 +608,27 @@ _ => false,
} } +fn is_immediate(op: &Operand) -> bool { + match op { + Operand::Immediate(_) => true, + _ => false, + } +} + +fn is_not_int_range(op: &Operand) -> bool { + match op { + Operand::Immediate(val) => { + if *val <= 2_i64.pow(31) - 1 && *val >= -2_i64.pow(31) { + return false; + } + true + } + Operand::Register(_) | Operand::Pseudo(_) | Operand::Stack(_) | Operand::Data(_) => false, + } +} + +// TODO: group cases into functions (i.e. both_operands_are_mem_accesses) +// for better overview fn fix_mem_accesses(asm: Asm) -> Result<Asm> { let mut new_functions: Vec<TopLevel> = Vec::new(); match asm {@@ -451,27 +639,106 @@ match function {
TopLevel::Function(name, global, instructions) => { for instr in instructions { match instr { - Instruction::Mov(op1, op2) => { + Instruction::Mov(asm_type, op1, op2) => { + let op1 = if asm_type == AssemblyType::Quadword + && is_not_int_range(&op1) + && is_mem_access(&op2) + { + new_instr.push(Instruction::Mov( + AssemblyType::Quadword, + op1, + Operand::Register(Reg::R10), + )); + Operand::Register(Reg::R10) + } else { + op1 + }; + let op1 = if asm_type == AssemblyType::Longword + && !is_not_int_range(&op1) + { + match op1 { + Operand::Immediate(x) => { + Operand::Immediate(x % 2_i64.pow(32)) + } + _ => op1, + } + } else { + op1 + }; if is_mem_access(&op1) && is_mem_access(&op2) { new_instr.push(Instruction::Mov( + asm_type.clone(), op1, Operand::Register(Reg::R10), )); new_instr.push(Instruction::Mov( + asm_type, Operand::Register(Reg::R10), op2, )); } else { - new_instr.push(Instruction::Mov(op1, op2)); + new_instr.push(Instruction::Mov(asm_type, op1, op2)); + } + } + Instruction::Movsx(op1, op2) => { + let new_op1 = if is_immediate(&op1) { + new_instr.push(Instruction::Mov( + AssemblyType::Longword, + op1, + Operand::Register(Reg::R10), + )); + Operand::Register(Reg::R10) + } else { + op1 + }; + if is_mem_access(&op2) { + new_instr.push(Instruction::Movsx( + new_op1, + Operand::Register(Reg::R11), + )); + new_instr.push(Instruction::Mov( + AssemblyType::Quadword, + Operand::Register(Reg::R11), + op2, + )); + } else { + new_instr.push(Instruction::Movsx(new_op1, op2)); } } - Instruction::Cmp(op1, op2) => { + Instruction::Push(op) => { + let op = if is_not_int_range(&op) { + new_instr.push(Instruction::Mov( + AssemblyType::Quadword, + op, + Operand::Register(Reg::R10), + )); + Operand::Register(Reg::R10) + } else { + op + }; + new_instr.push(Instruction::Push(op)); + } + Instruction::Cmp(asm_type, op1, op2) => { + let op1 = if asm_type == AssemblyType::Quadword + && is_not_int_range(&op1) + { + new_instr.push(Instruction::Mov( + AssemblyType::Quadword, + op1, + Operand::Register(Reg::R10), + )); + Operand::Register(Reg::R10) + } else { + op1 + }; if is_mem_access(&op1) && is_mem_access(&op2) { new_instr.push(Instruction::Mov( + asm_type.clone(), op1, Operand::Register(Reg::R10), )); new_instr.push(Instruction::Cmp( + asm_type, Operand::Register(Reg::R10), op2, ));@@ -479,55 +746,98 @@ } else {
match op2 { Operand::Immediate(c) => { new_instr.push(Instruction::Mov( + asm_type.clone(), Operand::Immediate(c), Operand::Register(Reg::R11), )); new_instr.push(Instruction::Cmp( + asm_type, op1, Operand::Register(Reg::R11), )); } - _ => new_instr.push(Instruction::Cmp(op1, op2)), + _ => { + new_instr.push(Instruction::Cmp(asm_type, op1, op2)) + } } } } - Instruction::Idiv(Operand::Immediate(c)) => { + Instruction::Idiv(asm_type, Operand::Immediate(c)) => { new_instr.push(Instruction::Mov( + asm_type.clone(), Operand::Immediate(c), Operand::Register(Reg::R10), )); - new_instr.push(Instruction::Idiv(Operand::Register(Reg::R10))); + new_instr.push(Instruction::Idiv( + asm_type, + Operand::Register(Reg::R10), + )); } - Instruction::Binary(BinOp::Mul, op1, op2) => { + Instruction::Binary(BinOp::Mul, asm_type, op1, op2) => { + let op1 = if asm_type == AssemblyType::Quadword + && is_not_int_range(&op1) + { + new_instr.push(Instruction::Mov( + AssemblyType::Quadword, + op1, + Operand::Register(Reg::R10), + )); + Operand::Register(Reg::R10) + } else { + op1 + }; if is_mem_access(&op2) { new_instr.push(Instruction::Mov( + asm_type.clone(), op2.clone(), Operand::Register(Reg::R11), )); new_instr.push(Instruction::Binary( BinOp::Mul, + asm_type.clone(), op1, Operand::Register(Reg::R11), )); new_instr.push(Instruction::Mov( + asm_type, Operand::Register(Reg::R11), op2, )); } } - Instruction::Binary(bin_op, op1, op2) => { + Instruction::Binary(bin_op, asm_type, op1, op2) => { + let op1 = if (bin_op == BinOp::Sub + || bin_op == BinOp::Add + || bin_op == BinOp::And + || bin_op == BinOp::Or + || bin_op == BinOp::Xor) + && asm_type == AssemblyType::Quadword + && is_not_int_range(&op1) + { + new_instr.push(Instruction::Mov( + AssemblyType::Quadword, + op1, + Operand::Register(Reg::R10), + )); + Operand::Register(Reg::R10) + } else { + op1 + }; if is_mem_access(&op1) && is_mem_access(&op2) { new_instr.push(Instruction::Mov( + asm_type.clone(), op1, Operand::Register(Reg::R10), )); new_instr.push(Instruction::Binary( bin_op, + asm_type, Operand::Register(Reg::R10), op2, )); } else { - new_instr.push(Instruction::Binary(bin_op, op1, op2)); + new_instr + .push(Instruction::Binary(bin_op, asm_type, op1, op2)); } } x => new_instr.push(x),@@ -551,7 +861,7 @@ let mut asm = match prog {
ir::TAC::Program(fun) => { let mut functions = Vec::new(); for f in fun { - functions.push(parse_function(f)?); + functions.push(parse_function(f, symbol_table)?); } Asm::Program(functions) }
@@ -107,7 +107,7 @@
let analyzed_ast = identifier_resolution::variable_resolution(ast)?; let analyzed_ast = label_resolution::label_resolution(analyzed_ast)?; let analyzed_ast = loop_resolution::loop_resolution(analyzed_ast)?; - let (analyzed_ast, symbol_table) = type_check::type_check(analyzed_ast)?; + let (analyzed_ast, mut symbol_table) = type_check::type_check(analyzed_ast)?; if args.validate { println!("Analyzed Ast for {:?}:", file);@@ -115,7 +115,7 @@ println!("{:?}", analyzed_ast);
continue; } - let ir = ir::lift_to_ir(analyzed_ast, &symbol_table)?; + let ir = ir::lift_to_ir(analyzed_ast, &mut symbol_table)?; if args.tacky { println!("IR for {:?}:", file);
@@ -1,3 +1,4 @@
+pub mod ast; pub mod identifier_resolution; pub mod ir; pub mod label_resolution;
@@ -0,0 +1,148 @@
+#[derive(Debug, PartialEq)] +pub enum Ast { + Program(Vec<Declaration>), +} + +#[derive(Debug, PartialEq)] +pub enum Declaration { + V(VariableDeclaration), + F(FunctionDeclaration), +} + +#[derive(Debug, PartialEq)] +pub enum Block { + B(Vec<BlockItem>), +} + +#[derive(Debug, PartialEq)] +pub enum BlockItem { + S(Statement), + D(Declaration), +} + +#[derive(Debug, PartialEq)] +pub enum VariableDeclaration { + D(String, Option<Expression>, Type, Option<StorageClass>), +} + +#[derive(Debug, PartialEq)] +pub enum FunctionDeclaration { + D( + String, + Vec<String>, + Option<Block>, + Type, + Option<StorageClass>, + ), +} + +#[derive(Debug, PartialEq, Clone)] +pub enum Type { + Int, + Long, + Function( + Vec<Box<Type>>, /* arg types */ + Box<Type>, /* ret type */ + ), +} + +#[derive(Debug, PartialEq)] +pub enum StorageClass { + Static, + Extern, +} + +#[derive(Debug, PartialEq)] +pub enum Statement { + Return(Expression), + Expression(Expression), + If(Expression, Box<Statement>, Option<Box<Statement>>), + While(Expression, Box<Statement>, Option<String>), + DoWhile(Box<Statement>, Expression, Option<String>), + For( + ForInit, + Option<Expression>, + Option<Expression>, + Box<Statement>, + Option<String>, + ), + Switch( + Expression, + Box<Statement>, + Option<String>, + Vec<(Option<Const>, String)>, + ), + Default(Box<Statement>, Option<String>), + Case(Expression, Box<Statement>, Option<String>), + Continue(Option<String>), + Break(Option<String>), + Compound(Block), + Labeled(String, Box<Statement>), + Goto(String), + Null, +} + +#[derive(Debug, PartialEq)] +pub enum ForInit { + D(VariableDeclaration), + E(Option<Expression>), +} + +#[derive(Debug, PartialEq)] +pub enum Expression { + Constant(Const, Option<Type>), + Variable(String, Option<Type>), + Cast(Type, Box<Expression>, Option<Type>), /* target type; expression to cast */ + Unary(UnaryOp, Box<Expression>, Option<Type>), + Binary(BinaryOp, Box<Expression>, Box<Expression>, Option<Type>), + Assignment(Box<Expression>, Box<Expression>, Option<Type>), + CompoundAssignment(BinaryOp, Box<Expression>, Box<Expression>, Option<Type>), + PostIncr(Box<Expression>, Option<Type>), + PostDecr(Box<Expression>, Option<Type>), + Conditional( + Box<Expression>, + Box<Expression>, + Box<Expression>, + Option<Type>, + ), + FunctionCall(String, Vec<Box<Expression>>, Option<Type>), +} + +#[derive(Debug, PartialEq, Clone)] +pub enum Const { + Int(i32), + Long(i64), +} + +#[derive(Debug, PartialEq)] +pub enum UnaryOp { + Complement, + Negation, + Not, + Increment, + Decrement, +} + +#[derive(Debug, PartialEq)] +pub enum BinaryOp { + Addition, + Subtraction, + Multiplication, + Division, + Modulo, + And, + Or, + Xor, + LShift, + RShift, + LAnd, + LOr, + Equal, + NEqual, + Less, + Greater, + LessEq, + GreaterEq, + + Assignment, +}
@@ -3,7 +3,7 @@ collections::HashMap,
sync::atomic::{AtomicUsize, Ordering}, }; -use crate::frontend::parse::{ +use crate::frontend::ast::{ Ast, Block, BlockItem, Declaration, Expression, ForInit, FunctionDeclaration, Statement, StorageClass, UnaryOp, VariableDeclaration, };@@ -19,10 +19,8 @@ fn copy_hashmap(
hash_map: &HashMap<String, (String, bool, bool)>, ) -> HashMap<String, (String, bool, bool)> { let mut result: HashMap<String, (String, bool, bool)> = HashMap::new(); - for key in hash_map.keys() { - let value = hash_map.get(key).unwrap().0.clone(); - let linkage = hash_map.get(key).unwrap().2.clone(); - result.insert(key.clone(), (value, false, linkage)); + for item in hash_map { + result.insert(item.0.clone(), (item.1.0.clone(), false, item.1.2.clone())); } result }@@ -32,70 +30,77 @@ expr: Expression,
hash_map: &mut HashMap<String, (String, bool, bool)>, ) -> Result<Expression> { match expr { - Expression::Variable(x) => { + Expression::Variable(x, _) => { if let Some(r) = hash_map.get(&x) { - return Ok(Expression::Variable(r.0.to_string())); + return Ok(Expression::Variable(r.0.to_string(), None)); } else { bail!("Undeclared variable {}", x); } } - Expression::Unary(unary_op, expression) => { + Expression::Unary(unary_op, expression, _) => { if (matches!(unary_op, UnaryOp::Increment) || matches!(unary_op, UnaryOp::Decrement)) - && !matches!(*expression, Expression::Variable(_)) + && !matches!(*expression, Expression::Variable(_, _)) { bail!("{:?} is not a valid lvalue", *expression); } Ok(Expression::Unary( unary_op, Box::new(resolve_expression(*expression, hash_map)?), + None, )) } - Expression::Binary(binary_op, left, right) => Ok(Expression::Binary( + Expression::Binary(binary_op, left, right, _) => Ok(Expression::Binary( binary_op, Box::new(resolve_expression(*left, hash_map)?), Box::new(resolve_expression(*right, hash_map)?), + None, )), - Expression::Assignment(left, right) => { - if !matches!(*left, Expression::Variable(_)) { + Expression::Assignment(left, right, _) => { + if !matches!(*left, Expression::Variable(_, _)) { bail!("{:?} is not a valid lvalue", *left); } Ok(Expression::Assignment( Box::new(resolve_expression(*left, hash_map)?), Box::new(resolve_expression(*right, hash_map)?), + None, )) } - Expression::CompoundAssignment(binary_op, left, right) => { - if !matches!(*left, Expression::Variable(_)) { + Expression::CompoundAssignment(binary_op, left, right, _) => { + if !matches!(*left, Expression::Variable(_, _)) { bail!("{:?} is not a valid lvalue", *left); } Ok(Expression::CompoundAssignment( binary_op, Box::new(resolve_expression(*left, hash_map)?), Box::new(resolve_expression(*right, hash_map)?), + None, )) } - Expression::PostIncr(expr) => { - if !matches!(*expr, Expression::Variable(_)) { + Expression::PostIncr(expr, _) => { + if !matches!(*expr, Expression::Variable(_, _)) { bail!("{:?} is not a valid lvalue", *expr); } - Ok(Expression::PostIncr(Box::new(resolve_expression( - *expr, hash_map, - )?))) + Ok(Expression::PostIncr( + Box::new(resolve_expression(*expr, hash_map)?), + None, + )) } - Expression::PostDecr(expr) => { - if !matches!(*expr, Expression::Variable(_)) { + Expression::PostDecr(expr, _) => { + if !matches!(*expr, Expression::Variable(_, _)) { bail!("{:?} is not a valid lvalue", *expr); } - Ok(Expression::PostDecr(Box::new(resolve_expression( - *expr, hash_map, - )?))) + Ok(Expression::PostDecr( + Box::new(resolve_expression(*expr, hash_map)?), + None, + )) } - Expression::Conditional(left, middle, right) => Ok(Expression::Conditional( + Expression::Conditional(left, middle, right, _) => Ok(Expression::Conditional( Box::new(resolve_expression(*left, hash_map)?), Box::new(resolve_expression(*middle, hash_map)?), Box::new(resolve_expression(*right, hash_map)?), + None, )), - Expression::FunctionCall(name, args) => { + Expression::FunctionCall(name, args, _) => { if !hash_map.contains_key(&name) { bail!("Undeclared identifier {}", name); }@@ -108,8 +113,14 @@
Ok(Expression::FunctionCall( hash_map.get(&name).unwrap().0.to_string(), new_args, + None, )) } + Expression::Cast(var_type, expression, _) => Ok(Expression::Cast( + var_type, + Box::new(resolve_expression(*expression, hash_map)?), + None, + )), c => Ok(c), } }@@ -223,7 +234,7 @@ decl: VariableDeclaration,
hash_map: &mut HashMap<String, (String, bool, bool)>, ) -> Result<VariableDeclaration> { match decl { - VariableDeclaration::D(id, opt_expression, storage_class) => { + VariableDeclaration::D(id, opt_expression, var_type, storage_class) => { if let Some(prev_entry) = hash_map.get(&id) { if prev_entry.1 && !(prev_entry.2 && matches!(storage_class, Some(StorageClass::Extern)))@@ -234,16 +245,31 @@ }
if matches!(storage_class, Some(StorageClass::Extern)) { hash_map.insert(id.clone(), (id.clone(), true, true)); - return Ok(VariableDeclaration::D(id, opt_expression, storage_class)); + return Ok(VariableDeclaration::D( + id, + opt_expression, + var_type, + storage_class, + )); } let unique = gen_temp_local(id.clone()); hash_map.insert(id, (unique.clone(), true, false)); if let Some(expr) = opt_expression { let expr = resolve_expression(expr, hash_map)?; - return Ok(VariableDeclaration::D(unique, Some(expr), storage_class)); + return Ok(VariableDeclaration::D( + unique, + Some(expr), + var_type, + storage_class, + )); } - Ok(VariableDeclaration::D(unique, None, storage_class)) + Ok(VariableDeclaration::D( + unique, + None, + var_type, + storage_class, + )) } } }@@ -265,9 +291,14 @@ decl: VariableDeclaration,
hash_map: &mut HashMap<String, (String, bool, bool)>, ) -> Result<VariableDeclaration> { match decl { - VariableDeclaration::D(id, opt_expression, storage_class) => { + VariableDeclaration::D(id, opt_expression, var_type, storage_class) => { hash_map.insert(id.clone(), (id.clone(), true, true)); - Ok(VariableDeclaration::D(id, opt_expression, storage_class)) + Ok(VariableDeclaration::D( + id, + opt_expression, + var_type, + storage_class, + )) } } }@@ -277,7 +308,7 @@ decl: FunctionDeclaration,
hash_map: &mut HashMap<String, (String, bool, bool)>, ) -> Result<FunctionDeclaration> { match decl { - FunctionDeclaration::D(name, params, block, storage_class) => { + FunctionDeclaration::D(name, params, block, var_type, storage_class) => { if hash_map.contains_key(&name) && hash_map.get(&name).unwrap().1 && !hash_map.get(&name).unwrap().2@@ -301,6 +332,7 @@ Ok(FunctionDeclaration::D(
name, new_params, new_block, + var_type, storage_class, )) }@@ -312,7 +344,7 @@ decl: FunctionDeclaration,
hash_map: &mut HashMap<String, (String, bool, bool)>, ) -> Result<FunctionDeclaration> { match &decl { - FunctionDeclaration::D(_, _, block, storage_class) => { + FunctionDeclaration::D(_, _, block, _, storage_class) => { if *block != None { bail!("Function definition inside another function is not allowed!"); }
@@ -6,8 +6,8 @@ sync::atomic::{AtomicUsize, Ordering},
}; use crate::frontend::{ - parse, - type_check::{IdentifierAttributes, Type}, + ast::{self, Type}, + type_check::{IdentifierAttributes, StaticInit, get_expression_type}, }; #[derive(Debug, PartialEq)]@@ -18,7 +18,12 @@
#[derive(Debug, PartialEq)] pub enum TopLevel { Function(String, bool, Vec<String>, Vec<Instruction>), - StaticVariable(String, bool, i32 /* name, global, init_value */), + StaticVariable( + String, + bool, + Type, + StaticInit, /* name, global, init_value */ + ), } #[derive(Debug, PartialEq)]@@ -31,12 +36,14 @@ JumpIfZero(Operand, String),
JumpIfNotZero(Operand, String), Label(String), Ret(Operand), + SignExtend(Operand, Operand), + Truncate(Operand, Operand), FunctionCall(String, Vec<Operand>, Operand), } #[derive(Debug, PartialEq, Clone)] pub enum Operand { - Constant(i32), + Constant(ast::Const), Variable(String), }@@ -69,6 +76,7 @@ LessEq,
GreaterEq, } +// TODO: rewrite display impl fmt::Display for TAC { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self {@@ -85,14 +93,23 @@
impl fmt::Display for TopLevel { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - TopLevel::Function(name, params, body, _) => { + TopLevel::Function(name, _, params, body) => { write!(f, "{} {:?}:\n", name, params)?; for i in body { write!(f, " {}\n", i)?; } Ok(()) } - TopLevel::StaticVariable(name, _, val) => write!(f, "glob {} = {}", name, val), + TopLevel::StaticVariable(name, _, _, val) => write!(f, "static {} = {}", name, val), + } + } +} + +impl fmt::Display for StaticInit { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + StaticInit::IntInit(x) => write!(f, "{}", x), + StaticInit::LongInit(x) => write!(f, "{}", x), } } }@@ -117,6 +134,17 @@ Instruction::Label(label) => write!(f, "\n{}:", label),
Instruction::FunctionCall(name, params, dst) => { write!(f, "{} = Call({}, {:?})", dst, name, params) } + Instruction::SignExtend(src, dst) => write!(f, "{} SIGNEXTEND TO {}", src, dst), + Instruction::Truncate(src, dst) => write!(f, "{} TRUNCATE TO {}", src, dst), + } + } +} + +impl fmt::Display for ast::Const { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ast::Const::Int(x) => write!(f, "{}", x), + ast::Const::Long(x) => write!(f, "{}", x), } } }@@ -168,9 +196,17 @@
static COUNTER_TMP: AtomicUsize = AtomicUsize::new(0); static COUNTER_LABEL: AtomicUsize = AtomicUsize::new(0); -fn gen_temp() -> Operand { +fn gen_temp( + var_type_opt: Option<Type>, + symbol_table: &mut HashMap<String, (Type, IdentifierAttributes)>, +) -> Result<Operand> { + let Some(var_type) = var_type_opt else { + bail!("No expression type found in lifting process"); + }; let counter = COUNTER_TMP.fetch_add(1, Ordering::SeqCst); - Operand::Variable("tmp.".to_string() + &counter.to_string()) + let name = "tmp.".to_string() + &counter.to_string(); + symbol_table.insert(name.clone(), (var_type, IdentifierAttributes::LocalAttr)); + Ok(Operand::Variable(name)) } fn gen_label() -> String {@@ -178,187 +214,257 @@ let counter = COUNTER_LABEL.fetch_add(1, Ordering::SeqCst);
"_".to_string() + &counter.to_string() } -fn parse_unary_op(expr: parse::UnaryOp) -> Result<UnOp> { +fn parse_unary_op(expr: ast::UnaryOp) -> Result<UnOp> { match expr { - parse::UnaryOp::Complement => Ok(UnOp::Complement), - parse::UnaryOp::Negation => Ok(UnOp::Negation), - parse::UnaryOp::Not => Ok(UnOp::Not), - parse::UnaryOp::Increment => Ok(UnOp::Increment), - parse::UnaryOp::Decrement => Ok(UnOp::Decrement), + ast::UnaryOp::Complement => Ok(UnOp::Complement), + ast::UnaryOp::Negation => Ok(UnOp::Negation), + ast::UnaryOp::Not => Ok(UnOp::Not), + ast::UnaryOp::Increment => Ok(UnOp::Increment), + ast::UnaryOp::Decrement => Ok(UnOp::Decrement), } } -fn parse_binary_op(expr: parse::BinaryOp) -> Result<BinOp> { +fn parse_binary_op(expr: ast::BinaryOp) -> Result<BinOp> { match expr { - parse::BinaryOp::Addition => Ok(BinOp::Addition), - parse::BinaryOp::Subtraction => Ok(BinOp::Subtraction), - parse::BinaryOp::Multiplication => Ok(BinOp::Multiplication), - parse::BinaryOp::Division => Ok(BinOp::Division), - parse::BinaryOp::Modulo => Ok(BinOp::Modulo), - parse::BinaryOp::And => Ok(BinOp::And), - parse::BinaryOp::Or => Ok(BinOp::Or), - parse::BinaryOp::Xor => Ok(BinOp::Xor), - parse::BinaryOp::LShift => Ok(BinOp::LShift), - parse::BinaryOp::RShift => Ok(BinOp::RShift), - parse::BinaryOp::Equal => Ok(BinOp::Equal), - parse::BinaryOp::NEqual => Ok(BinOp::NEqual), - parse::BinaryOp::Less => Ok(BinOp::Less), - parse::BinaryOp::Greater => Ok(BinOp::Greater), - parse::BinaryOp::LessEq => Ok(BinOp::LessEq), - parse::BinaryOp::GreaterEq => Ok(BinOp::GreaterEq), + ast::BinaryOp::Addition => Ok(BinOp::Addition), + ast::BinaryOp::Subtraction => Ok(BinOp::Subtraction), + ast::BinaryOp::Multiplication => Ok(BinOp::Multiplication), + ast::BinaryOp::Division => Ok(BinOp::Division), + ast::BinaryOp::Modulo => Ok(BinOp::Modulo), + ast::BinaryOp::And => Ok(BinOp::And), + ast::BinaryOp::Or => Ok(BinOp::Or), + ast::BinaryOp::Xor => Ok(BinOp::Xor), + ast::BinaryOp::LShift => Ok(BinOp::LShift), + ast::BinaryOp::RShift => Ok(BinOp::RShift), + ast::BinaryOp::Equal => Ok(BinOp::Equal), + ast::BinaryOp::NEqual => Ok(BinOp::NEqual), + ast::BinaryOp::Less => Ok(BinOp::Less), + ast::BinaryOp::Greater => Ok(BinOp::Greater), + ast::BinaryOp::LessEq => Ok(BinOp::LessEq), + ast::BinaryOp::GreaterEq => Ok(BinOp::GreaterEq), x => bail!("{:?} should be handled seperately", x), } } fn parse_expression( - expr: parse::Expression, + expr: ast::Expression, instructions: &mut Vec<Instruction>, + symbol_table: &mut HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<Operand> { match expr { - parse::Expression::Constant(c) => Ok(Operand::Constant(c)), - parse::Expression::Unary(unary_op, expression) => { - let src = parse_expression(*expression, instructions)?; + ast::Expression::Constant(c, _) => Ok(Operand::Constant(c)), + ast::Expression::Unary(unary_op, expression, var_type) => { + let src = parse_expression(*expression, instructions, symbol_table)?; let op = parse_unary_op(unary_op)?; let dst; if matches!(op, UnOp::Increment) || matches!(op, UnOp::Decrement) { dst = src.clone(); } else { - dst = gen_temp(); + dst = gen_temp(var_type, symbol_table)?; } instructions.push(Instruction::Unary(op, src, dst.clone())); Ok(dst) } - parse::Expression::Binary( - binary_op @ (parse::BinaryOp::LAnd | parse::BinaryOp::LOr), + ast::Expression::Binary( + binary_op @ (ast::BinaryOp::LAnd | ast::BinaryOp::LOr), expression1, expression2, + var_type, ) => { let label = gen_label(); let end_label = label.to_string() + "_end"; - let src1 = parse_expression(*expression1, instructions)?; - let dst = gen_temp(); - if binary_op == parse::BinaryOp::LAnd { + let src1 = parse_expression(*expression1, instructions, symbol_table)?; + let dst = gen_temp(var_type, symbol_table)?; + if binary_op == ast::BinaryOp::LAnd { instructions.push(Instruction::JumpIfZero(src1, label.clone())); } else { instructions.push(Instruction::JumpIfNotZero(src1, label.clone())); } - // TODO: try to remove the clones (references in enum?) - let src2 = parse_expression(*expression2, instructions)?; - if binary_op == parse::BinaryOp::LAnd { + let src2 = parse_expression(*expression2, instructions, symbol_table)?; + if binary_op == ast::BinaryOp::LAnd { instructions.push(Instruction::JumpIfZero(src2, label.clone())); - instructions.push(Instruction::Copy(Operand::Constant(1), dst.clone())); + instructions.push(Instruction::Copy( + Operand::Constant(ast::Const::Int(1)), + dst.clone(), + )); instructions.push(Instruction::Jump(end_label.clone())); - instructions.push(Instruction::Label(label)); - instructions.push(Instruction::Copy(Operand::Constant(0), dst.clone())); + instructions.push(Instruction::Label(label)); // TODO: look if labels are correct + instructions.push(Instruction::Copy( + Operand::Constant(ast::Const::Int(0)), + dst.clone(), + )); } else { instructions.push(Instruction::JumpIfNotZero(src2, label.clone())); - instructions.push(Instruction::Copy(Operand::Constant(0), dst.clone())); + instructions.push(Instruction::Copy( + Operand::Constant(ast::Const::Int(0)), + dst.clone(), + )); instructions.push(Instruction::Jump(end_label.clone())); instructions.push(Instruction::Label(label)); - instructions.push(Instruction::Copy(Operand::Constant(1), dst.clone())); + instructions.push(Instruction::Copy( + Operand::Constant(ast::Const::Int(1)), + dst.clone(), + )); } instructions.push(Instruction::Label(end_label)); Ok(dst) } - parse::Expression::Binary(binary_op, expression1, expression2) => { - let src1 = parse_expression(*expression1, instructions)?; - let src2 = parse_expression(*expression2, instructions)?; - let dst = gen_temp(); + ast::Expression::Binary(binary_op, expression1, expression2, var_type) => { + let src1 = parse_expression(*expression1, instructions, symbol_table)?; + let src2 = parse_expression(*expression2, instructions, symbol_table)?; + let dst = gen_temp(var_type, symbol_table)?; let op = parse_binary_op(binary_op)?; instructions.push(Instruction::Binary(op, src1, src2, dst.clone())); Ok(dst) } - parse::Expression::Variable(v) => Ok(Operand::Variable(v)), - parse::Expression::Assignment(var, right) => { - if let parse::Expression::Variable(v) = *var { - let src = parse_expression(*right, instructions)?; + ast::Expression::Variable(v, _) => Ok(Operand::Variable(v)), + ast::Expression::Assignment(var, right, _) => { + if let ast::Expression::Variable(v, _) = *var { + let src = parse_expression(*right, instructions, symbol_table)?; instructions.push(Instruction::Copy(src, Operand::Variable(v.clone()))); Ok(Operand::Variable(v)) } else { bail!("Lvalue of Assignment must be a variable!"); } } - parse::Expression::CompoundAssignment(binary_op, var, right) => { - if let parse::Expression::Variable(v) = *var { - let right = parse_expression(*right, instructions)?; - instructions.push(Instruction::Binary( - parse_binary_op(binary_op)?, - Operand::Variable(v.clone()), - right, - Operand::Variable(v.clone()), - )); + ast::Expression::CompoundAssignment(binary_op, var, right, expr_type) => { + if let ast::Expression::Variable(v, Some(var_type)) = *var { + let right_type = get_expression_type(&right)?.clone(); + let right_operand = parse_expression(*right, instructions, symbol_table)?; + if Some(var_type) == expr_type { + instructions.push(Instruction::Binary( + parse_binary_op(binary_op)?, + Operand::Variable(v.clone()), + right_operand, + Operand::Variable(v.clone()), + )); + } else { + let tmp = gen_temp(expr_type, symbol_table)?; + match right_type { + Type::Int => { + instructions.push(Instruction::Truncate( + Operand::Variable(v.clone()), + tmp.clone(), + )); + instructions.push(Instruction::Binary( + parse_binary_op(binary_op)?, + tmp.clone(), + right_operand, + tmp.clone(), + )); + instructions + .push(Instruction::SignExtend(tmp, Operand::Variable(v.clone()))); + } + Type::Long => { + instructions.push(Instruction::SignExtend( + Operand::Variable(v.clone()), + tmp.clone(), + )); + instructions.push(Instruction::Binary( + parse_binary_op(binary_op)?, + tmp.clone(), + right_operand, + tmp.clone(), + )); + instructions + .push(Instruction::Truncate(tmp, Operand::Variable(v.clone()))); + } + + Type::Function(_, _) => bail!("function?"), + } + } Ok(Operand::Variable(v)) } else { bail!("Lvalue of Assignment must be a variable!"); } } - parse::Expression::PostIncr(expr) => { - let dst = gen_temp(); - let src = parse_expression(*expr, instructions)?; + ast::Expression::PostIncr(expr, var_type) => { + let dst = gen_temp(var_type, symbol_table)?; + let src = parse_expression(*expr, instructions, symbol_table)?; instructions.push(Instruction::Copy(src.clone(), dst.clone())); instructions.push(Instruction::Binary( BinOp::Addition, src.clone(), - Operand::Constant(1), + Operand::Constant(ast::Const::Int(1)), src, )); Ok(dst) } - parse::Expression::PostDecr(expr) => { - let dst = gen_temp(); - let src = parse_expression(*expr, instructions)?; + ast::Expression::PostDecr(expr, var_type) => { + let dst = gen_temp(var_type, symbol_table)?; + let src = parse_expression(*expr, instructions, symbol_table)?; instructions.push(Instruction::Copy(src.clone(), dst.clone())); instructions.push(Instruction::Binary( BinOp::Subtraction, src.clone(), - Operand::Constant(1), + Operand::Constant(ast::Const::Int(1)), src, )); Ok(dst) } - parse::Expression::Conditional(left, middle, right) => { - let dst = gen_temp(); + ast::Expression::Conditional(left, middle, right, var_type) => { + let dst = gen_temp(var_type, symbol_table)?; let label_false = gen_label(); let end_label = label_false.to_string() + "_end"; - let cond = parse_expression(*left, instructions)?; + let cond = parse_expression(*left, instructions, symbol_table)?; instructions.push(Instruction::JumpIfZero(cond, label_false.clone())); - let middle = parse_expression(*middle, instructions)?; + let middle = parse_expression(*middle, instructions, symbol_table)?; instructions.push(Instruction::Copy(middle, dst.clone())); instructions.push(Instruction::Jump(end_label.clone())); instructions.push(Instruction::Label(label_false)); - let right = parse_expression(*right, instructions)?; + let right = parse_expression(*right, instructions, symbol_table)?; instructions.push(Instruction::Copy(right, dst.clone())); instructions.push(Instruction::Label(end_label)); Ok(dst) } - parse::Expression::FunctionCall(name, expressions) => { - let dst = gen_temp(); + ast::Expression::FunctionCall(name, expressions, var_type) => { + let dst = gen_temp(var_type, symbol_table)?; let mut operands = Vec::new(); for expr in expressions { - operands.push(parse_expression(*expr, instructions)?); + operands.push(parse_expression(*expr, instructions, symbol_table)?); } instructions.push(Instruction::FunctionCall(name, operands, dst.clone())); Ok(dst) } + ast::Expression::Cast(target_type, expression, _) => { + let expr_type = get_expression_type(&expression)?.clone(); + let result = parse_expression(*expression, instructions, symbol_table)?; + if target_type == expr_type { + return Ok(result); + } + let dst = gen_temp(Some(target_type.clone()), symbol_table)?; + if target_type == Type::Long { + instructions.push(Instruction::SignExtend(result, dst.clone())); + } else { + instructions.push(Instruction::Truncate(result, dst.clone())); + } + return Ok(dst); + } } } -fn parse_for_init(for_init: parse::ForInit, instructions: &mut Vec<Instruction>) -> Result<()> { +fn parse_for_init( + for_init: ast::ForInit, + instructions: &mut Vec<Instruction>, + symbol_table: &mut HashMap<String, (Type, IdentifierAttributes)>, +) -> Result<()> { match for_init { - parse::ForInit::D(declaration) => parse_variable_declaration(declaration, instructions), - parse::ForInit::E(opt_expression) => match opt_expression { + ast::ForInit::D(declaration) => { + parse_variable_declaration(declaration, instructions, symbol_table) + } + ast::ForInit::E(opt_expression) => match opt_expression { Some(expression) => { - parse_expression(expression, instructions)?; + parse_expression(expression, instructions, symbol_table)?; Ok(()) } None => Ok(()),@@ -366,60 +472,64 @@ },
} } -fn parse_statement(statement: parse::Statement, instructions: &mut Vec<Instruction>) -> Result<()> { +fn parse_statement( + statement: ast::Statement, + instructions: &mut Vec<Instruction>, + symbol_table: &mut HashMap<String, (Type, IdentifierAttributes)>, +) -> Result<()> { match statement { - parse::Statement::Return(expression) => { - let dst = parse_expression(expression, instructions)?; + ast::Statement::Return(expression) => { + let dst = parse_expression(expression, instructions, symbol_table)?; instructions.push(Instruction::Ret(dst)); Ok(()) } - parse::Statement::Expression(expression) => { - parse_expression(expression, instructions)?; + ast::Statement::Expression(expression) => { + parse_expression(expression, instructions, symbol_table)?; Ok(()) } - parse::Statement::Null => Ok(()), - parse::Statement::If(condition, if_statement, else_statement) => { + ast::Statement::Null => Ok(()), + ast::Statement::If(condition, if_statement, else_statement) => { let label_else = gen_label(); let end_label = label_else.to_string() + "_else"; - let cond = parse_expression(condition, instructions)?; + let cond = parse_expression(condition, instructions, symbol_table)?; instructions.push(Instruction::JumpIfZero(cond, label_else.clone())); - parse_statement(*if_statement, instructions)?; + parse_statement(*if_statement, instructions, symbol_table)?; instructions.push(Instruction::Jump(end_label.clone())); instructions.push(Instruction::Label(label_else)); if let Some(x) = else_statement { - parse_statement(*x, instructions)?; + parse_statement(*x, instructions, symbol_table)?; } instructions.push(Instruction::Label(end_label)); Ok(()) } - parse::Statement::Labeled(label, statement) => { + ast::Statement::Labeled(label, statement) => { instructions.push(Instruction::Label(label)); - parse_statement(*statement, instructions) + parse_statement(*statement, instructions, symbol_table) } - parse::Statement::Goto(label) => { + ast::Statement::Goto(label) => { instructions.push(Instruction::Jump(label)); Ok(()) } - parse::Statement::Compound(block) => parse_block(block, instructions), - parse::Statement::While(expression, statement, label) => { + ast::Statement::Compound(block) => parse_block(block, instructions, symbol_table), + ast::Statement::While(expression, statement, label) => { instructions.push(Instruction::Label(label.clone().unwrap() + "_continue")); - let cond = parse_expression(expression, instructions)?; + let cond = parse_expression(expression, instructions, symbol_table)?; instructions.push(Instruction::JumpIfZero( cond, label.clone().unwrap() + "_break", )); - parse_statement(*statement, instructions)?; + parse_statement(*statement, instructions, symbol_table)?; instructions.push(Instruction::Jump(label.clone().unwrap() + "_continue")); instructions.push(Instruction::Label(label.unwrap() + "_break")); Ok(()) } - parse::Statement::DoWhile(statement, expression, label) => { + ast::Statement::DoWhile(statement, expression, label) => { instructions.push(Instruction::Label(label.clone().unwrap() + "_start")); - parse_statement(*statement, instructions)?; + parse_statement(*statement, instructions, symbol_table)?; instructions.push(Instruction::Label(label.clone().unwrap() + "_continue")); - let cond = parse_expression(expression, instructions)?; + let cond = parse_expression(expression, instructions, symbol_table)?; instructions.push(Instruction::JumpIfNotZero( cond, label.clone().unwrap() + "_start",@@ -427,36 +537,37 @@ ));
instructions.push(Instruction::Label(label.unwrap() + "_break")); Ok(()) } - parse::Statement::For(for_init, opt_condition, opt_step, statement, label) => { - parse_for_init(for_init, instructions)?; + ast::Statement::For(for_init, opt_condition, opt_step, statement, label) => { + parse_for_init(for_init, instructions, symbol_table)?; instructions.push(Instruction::Label(label.clone().unwrap() + "_start")); if let Some(condition) = opt_condition { - let cond = parse_expression(condition, instructions)?; + let cond = parse_expression(condition, instructions, symbol_table)?; instructions.push(Instruction::JumpIfZero( cond, label.clone().unwrap() + "_break", )); } - parse_statement(*statement, instructions)?; + parse_statement(*statement, instructions, symbol_table)?; instructions.push(Instruction::Label(label.clone().unwrap() + "_continue")); if let Some(step) = opt_step { - parse_expression(step, instructions)?; + parse_expression(step, instructions, symbol_table)?; } instructions.push(Instruction::Jump(label.clone().unwrap() + "_start")); instructions.push(Instruction::Label(label.unwrap() + "_break")); Ok(()) } - parse::Statement::Continue(label) => { + ast::Statement::Continue(label) => { instructions.push(Instruction::Jump(label.unwrap() + "_continue")); Ok(()) } - parse::Statement::Break(label) => { + ast::Statement::Break(label) => { instructions.push(Instruction::Jump(label.unwrap() + "_break")); Ok(()) } - parse::Statement::Switch(expression, statement, label, items) => { - let switch_operand = parse_expression(expression, instructions)?; + ast::Statement::Switch(expression, statement, label, items) => { + let expr_type = get_expression_type(&expression)?.clone(); + let switch_operand = parse_expression(expression, instructions, symbol_table)?; let mut default_case: Option<String> = None; for item in items {@@ -464,7 +575,7 @@ if item.0 == None {
default_case = Some(item.1); continue; } - let cmp_tmp = gen_temp(); + let cmp_tmp = gen_temp(Some(expr_type.clone()), symbol_table)?; if let Some(constant) = item.0 { instructions.push(Instruction::Binary( BinOp::Equal,@@ -482,33 +593,34 @@ } else {
instructions.push(Instruction::Jump(label.clone().unwrap() + "_break")); } - parse_statement(*statement, instructions)?; + parse_statement(*statement, instructions, symbol_table)?; instructions.push(Instruction::Label(label.unwrap() + "_break")); Ok(()) } - parse::Statement::Default(statement, label) => { + ast::Statement::Default(statement, label) => { instructions.push(Instruction::Label(label.unwrap())); - parse_statement(*statement, instructions)?; + parse_statement(*statement, instructions, symbol_table)?; Ok(()) } - parse::Statement::Case(_, statement, label) => { + ast::Statement::Case(_, statement, label) => { instructions.push(Instruction::Label(label.unwrap())); - parse_statement(*statement, instructions)?; + parse_statement(*statement, instructions, symbol_table)?; Ok(()) } } } fn parse_variable_declaration( - decl: parse::VariableDeclaration, + decl: ast::VariableDeclaration, instructions: &mut Vec<Instruction>, + symbol_table: &mut HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<()> { match decl { - parse::VariableDeclaration::D(id, opt_expression, storage_class) => { + ast::VariableDeclaration::D(id, opt_expression, _, storage_class) => { if storage_class == None && let Some(expression) = opt_expression { - let src = parse_expression(expression, instructions)?; + let src = parse_expression(expression, instructions, symbol_table)?; instructions.push(Instruction::Copy(src, Operand::Variable(id))); } Ok(())@@ -516,27 +628,41 @@ }
} } -fn parse_declaration(decl: parse::Declaration, instructions: &mut Vec<Instruction>) -> Result<()> { +fn parse_declaration( + decl: ast::Declaration, + instructions: &mut Vec<Instruction>, + symbol_table: &mut HashMap<String, (Type, IdentifierAttributes)>, +) -> Result<()> { match decl { - parse::Declaration::V(variable_declaration) => { - parse_variable_declaration(variable_declaration, instructions) + ast::Declaration::V(variable_declaration) => { + parse_variable_declaration(variable_declaration, instructions, symbol_table) } - parse::Declaration::F(_) => Ok(()), + ast::Declaration::F(_) => Ok(()), } } -fn parse_block_item(bl: parse::BlockItem, instructions: &mut Vec<Instruction>) -> Result<()> { +fn parse_block_item( + bl: ast::BlockItem, + instructions: &mut Vec<Instruction>, + symbol_table: &mut HashMap<String, (Type, IdentifierAttributes)>, +) -> Result<()> { match bl { - parse::BlockItem::S(statement) => parse_statement(statement, instructions), - parse::BlockItem::D(declaration) => parse_declaration(declaration, instructions), + ast::BlockItem::S(statement) => parse_statement(statement, instructions, symbol_table), + ast::BlockItem::D(declaration) => { + parse_declaration(declaration, instructions, symbol_table) + } } } -fn parse_block(bl: parse::Block, instructions: &mut Vec<Instruction>) -> Result<()> { +fn parse_block( + bl: ast::Block, + instructions: &mut Vec<Instruction>, + symbol_table: &mut HashMap<String, (Type, IdentifierAttributes)>, +) -> Result<()> { match bl { - parse::Block::B(block_items) => { + ast::Block::B(block_items) => { for block in block_items { - parse_block_item(block, instructions)?; + parse_block_item(block, instructions, symbol_table)?; } Ok(()) }@@ -544,15 +670,15 @@ }
} fn parse_function_declaration( - fun: parse::FunctionDeclaration, - symbol_table: &HashMap<String, (Type, IdentifierAttributes)>, + fun: ast::FunctionDeclaration, + symbol_table: &mut HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<Option<TopLevel>> { let mut instructions = Vec::new(); match fun { - parse::FunctionDeclaration::D(name, params, body, _) => { + ast::FunctionDeclaration::D(name, params, body, _, _) => { if let Some(bl) = body { - parse_block(bl, &mut instructions)?; - instructions.push(Instruction::Ret(Operand::Constant(0))); + parse_block(bl, &mut instructions, symbol_table)?; + instructions.push(Instruction::Ret(Operand::Constant(ast::Const::Int(0)))); // TODO: write symbol_table util functions (is_global, is_static...) let global = match symbol_table.get(&name) {@@ -577,15 +703,24 @@ fn convert_symbols_to_ir(
symbol_table: &HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<Vec<TopLevel>> { let mut result = Vec::new(); - for (name, (_, id_attr)) in symbol_table { + for (name, (var_type, id_attr)) in symbol_table { match id_attr { IdentifierAttributes::StaticAttributes(initial_value, global) => match initial_value { super::type_check::InitialValue::Tentative => { - result.push(TopLevel::StaticVariable(name.clone(), *global, 0)) + result.push(TopLevel::StaticVariable( + name.clone(), + *global, + var_type.clone(), + match var_type { + Type::Int => StaticInit::IntInit(0), + Type::Long => StaticInit::LongInit(0), + Type::Function(_, _) => bail!("No function allowed here"), + }, + )) } - super::type_check::InitialValue::Initial(i) => { - result.push(TopLevel::StaticVariable(name.clone(), *global, *i)) - } + super::type_check::InitialValue::Initial(i) => result.push( + TopLevel::StaticVariable(name.clone(), *global, var_type.clone(), i.clone()), + ), super::type_check::InitialValue::NoInit => (), }, _ => (),@@ -596,18 +731,18 @@ Ok(result)
} pub fn lift_to_ir( - prog: parse::Ast, - symbol_table: &HashMap<String, (Type, IdentifierAttributes)>, + prog: ast::Ast, + symbol_table: &mut HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<TAC> { COUNTER_TMP.store(0, Ordering::SeqCst); COUNTER_LABEL.store(0, Ordering::SeqCst); match prog { - parse::Ast::Program(functions) => { + ast::Ast::Program(functions) => { let mut top_level_ir = Vec::new(); for func in functions { match func { - parse::Declaration::V(_) => (), - parse::Declaration::F(function_declaration) => { + ast::Declaration::V(_) => (), + ast::Declaration::F(function_declaration) => { if let Some(function) = parse_function_declaration(function_declaration, symbol_table)? {
@@ -3,7 +3,7 @@ collections::HashMap,
sync::atomic::{AtomicUsize, Ordering}, }; -use crate::frontend::parse::{Ast, Block, BlockItem, Declaration, FunctionDeclaration, Statement}; +use crate::frontend::ast::{Ast, Block, BlockItem, Declaration, FunctionDeclaration, Statement}; use anyhow::{Result, bail}; fn gen_label(id: String) -> String {@@ -189,12 +189,13 @@ decl: Declaration,
hash_map: &mut HashMap<String, String>, ) -> Result<Declaration> { match decl { - Declaration::F(FunctionDeclaration::D(name, args, block, storage_class)) => { + Declaration::F(FunctionDeclaration::D(name, args, block, var_type, storage_class)) => { let Some(bl) = block else { return Ok(Declaration::F(FunctionDeclaration::D( name, args, block, + var_type, storage_class, ))); };@@ -204,6 +205,7 @@ Ok(Declaration::F(FunctionDeclaration::D(
name, args, block, + var_type, storage_class, ))) }
@@ -6,8 +6,10 @@
#[derive(Debug, PartialEq)] pub enum Token { Identifier(String), - Constant(i32), + IntConstant(i64), + LongConstant(i64), Int, + Long, Void, Return, OpenParanthesis,@@ -67,70 +69,76 @@ Extern,
} impl Token { - pub fn patterns() -> Vec<(Regex, fn(&str) -> Token)> { + pub fn patterns() -> Vec<(Regex, fn(&str) -> Result<Token>)> { vec![ - (Regex::new(r"<<\=").unwrap(), |_| Token::CLShift), - (Regex::new(r">>\=").unwrap(), |_| Token::CRShift), - (Regex::new(r"\-\=").unwrap(), |_| Token::CNegation), - (Regex::new(r"\+\=").unwrap(), |_| Token::CAddition), - (Regex::new(r"\^\=").unwrap(), |_| Token::CXor), - (Regex::new(r"\|\=").unwrap(), |_| Token::COr), - (Regex::new(r"\&\=").unwrap(), |_| Token::CAnd), - (Regex::new(r"\*\=").unwrap(), |_| Token::CMultiplication), - (Regex::new(r"\/\=").unwrap(), |_| Token::CDivision), - (Regex::new(r"\%\=").unwrap(), |_| Token::CModulo), - (Regex::new(r"\&\&").unwrap(), |_| Token::LAnd), - (Regex::new(r"\|\|").unwrap(), |_| Token::LOr), - (Regex::new(r"\=\=").unwrap(), |_| Token::Equal), - (Regex::new(r"\!\=").unwrap(), |_| Token::NEqual), - (Regex::new(r"<<").unwrap(), |_| Token::LShift), - (Regex::new(r">>").unwrap(), |_| Token::RShift), - (Regex::new(r"<\=").unwrap(), |_| Token::LessEq), - (Regex::new(r">\=").unwrap(), |_| Token::GreaterEq), - (Regex::new(r"\!").unwrap(), |_| Token::Not), - (Regex::new(r"\+\+").unwrap(), |_| Token::Increment), - (Regex::new(r"\-\-").unwrap(), |_| Token::Decrement), - (Regex::new(r"\?").unwrap(), |_| Token::QuestionMark), - (Regex::new(r"\:").unwrap(), |_| Token::Colon), - (Regex::new(r"\,").unwrap(), |_| Token::Comma), - (Regex::new(r"<").unwrap(), |_| Token::Less), - (Regex::new(r">").unwrap(), |_| Token::Greater), - (Regex::new(r"\-").unwrap(), |_| Token::Negation), - (Regex::new(r"\=").unwrap(), |_| Token::Assignment), - (Regex::new(r"\+").unwrap(), |_| Token::Addition), - (Regex::new(r"\^").unwrap(), |_| Token::Xor), - (Regex::new(r"\|").unwrap(), |_| Token::Or), - (Regex::new(r"\&").unwrap(), |_| Token::And), - (Regex::new(r"\*").unwrap(), |_| Token::Multiplication), - (Regex::new(r"\/").unwrap(), |_| Token::Division), - (Regex::new(r"\%").unwrap(), |_| Token::Modulo), - (Regex::new(r"\~").unwrap(), |_| Token::Complement), - (Regex::new(r"\{").unwrap(), |_| Token::OpenBrace), - (Regex::new(r"\}").unwrap(), |_| Token::CloseBrace), - (Regex::new(r"\(").unwrap(), |_| Token::OpenParanthesis), - (Regex::new(r"\)").unwrap(), |_| Token::CloseParanthesis), - (Regex::new(r"\;").unwrap(), |_| Token::Semicolon), - (Regex::new(r"return\b").unwrap(), |_| Token::Return), - (Regex::new(r"void\b").unwrap(), |_| Token::Void), - (Regex::new(r"static\b").unwrap(), |_| Token::Static), - (Regex::new(r"extern\b").unwrap(), |_| Token::Extern), - (Regex::new(r"int\b").unwrap(), |_| Token::Int), - (Regex::new(r"if\b").unwrap(), |_| Token::If), - (Regex::new(r"do\b").unwrap(), |_| Token::Do), - (Regex::new(r"while\b").unwrap(), |_| Token::While), - (Regex::new(r"for\b").unwrap(), |_| Token::For), - (Regex::new(r"break\b").unwrap(), |_| Token::Break), - (Regex::new(r"continue\b").unwrap(), |_| Token::Continue), - (Regex::new(r"else\b").unwrap(), |_| Token::Else), - (Regex::new(r"goto\b").unwrap(), |_| Token::Goto), - (Regex::new(r"switch\b").unwrap(), |_| Token::Switch), - (Regex::new(r"case\b").unwrap(), |_| Token::Case), - (Regex::new(r"default\b").unwrap(), |_| Token::Default), + (Regex::new(r"<<\=").unwrap(), |_| Ok(Token::CLShift)), + (Regex::new(r">>\=").unwrap(), |_| Ok(Token::CRShift)), + (Regex::new(r"\-\=").unwrap(), |_| Ok(Token::CNegation)), + (Regex::new(r"\+\=").unwrap(), |_| Ok(Token::CAddition)), + (Regex::new(r"\^\=").unwrap(), |_| Ok(Token::CXor)), + (Regex::new(r"\|\=").unwrap(), |_| Ok(Token::COr)), + (Regex::new(r"\&\=").unwrap(), |_| Ok(Token::CAnd)), + (Regex::new(r"\*\=").unwrap(), |_| Ok(Token::CMultiplication)), + (Regex::new(r"\/\=").unwrap(), |_| Ok(Token::CDivision)), + (Regex::new(r"\%\=").unwrap(), |_| Ok(Token::CModulo)), + (Regex::new(r"\&\&").unwrap(), |_| Ok(Token::LAnd)), + (Regex::new(r"\|\|").unwrap(), |_| Ok(Token::LOr)), + (Regex::new(r"\=\=").unwrap(), |_| Ok(Token::Equal)), + (Regex::new(r"\!\=").unwrap(), |_| Ok(Token::NEqual)), + (Regex::new(r"<<").unwrap(), |_| Ok(Token::LShift)), + (Regex::new(r">>").unwrap(), |_| Ok(Token::RShift)), + (Regex::new(r"<\=").unwrap(), |_| Ok(Token::LessEq)), + (Regex::new(r">\=").unwrap(), |_| Ok(Token::GreaterEq)), + (Regex::new(r"\!").unwrap(), |_| Ok(Token::Not)), + (Regex::new(r"\+\+").unwrap(), |_| Ok(Token::Increment)), + (Regex::new(r"\-\-").unwrap(), |_| Ok(Token::Decrement)), + (Regex::new(r"\?").unwrap(), |_| Ok(Token::QuestionMark)), + (Regex::new(r"\:").unwrap(), |_| Ok(Token::Colon)), + (Regex::new(r"\,").unwrap(), |_| Ok(Token::Comma)), + (Regex::new(r"<").unwrap(), |_| Ok(Token::Less)), + (Regex::new(r">").unwrap(), |_| Ok(Token::Greater)), + (Regex::new(r"\-").unwrap(), |_| Ok(Token::Negation)), + (Regex::new(r"\=").unwrap(), |_| Ok(Token::Assignment)), + (Regex::new(r"\+").unwrap(), |_| Ok(Token::Addition)), + (Regex::new(r"\^").unwrap(), |_| Ok(Token::Xor)), + (Regex::new(r"\|").unwrap(), |_| Ok(Token::Or)), + (Regex::new(r"\&").unwrap(), |_| Ok(Token::And)), + (Regex::new(r"\*").unwrap(), |_| Ok(Token::Multiplication)), + (Regex::new(r"\/").unwrap(), |_| Ok(Token::Division)), + (Regex::new(r"\%").unwrap(), |_| Ok(Token::Modulo)), + (Regex::new(r"\~").unwrap(), |_| Ok(Token::Complement)), + (Regex::new(r"\{").unwrap(), |_| Ok(Token::OpenBrace)), + (Regex::new(r"\}").unwrap(), |_| Ok(Token::CloseBrace)), + (Regex::new(r"\(").unwrap(), |_| Ok(Token::OpenParanthesis)), + (Regex::new(r"\)").unwrap(), |_| Ok(Token::CloseParanthesis)), + (Regex::new(r"\;").unwrap(), |_| Ok(Token::Semicolon)), + (Regex::new(r"return\b").unwrap(), |_| Ok(Token::Return)), + (Regex::new(r"void\b").unwrap(), |_| Ok(Token::Void)), + (Regex::new(r"static\b").unwrap(), |_| Ok(Token::Static)), + (Regex::new(r"extern\b").unwrap(), |_| Ok(Token::Extern)), + (Regex::new(r"int\b").unwrap(), |_| Ok(Token::Int)), + (Regex::new(r"long\b").unwrap(), |_| Ok(Token::Long)), + (Regex::new(r"if\b").unwrap(), |_| Ok(Token::If)), + (Regex::new(r"do\b").unwrap(), |_| Ok(Token::Do)), + (Regex::new(r"while\b").unwrap(), |_| Ok(Token::While)), + (Regex::new(r"for\b").unwrap(), |_| Ok(Token::For)), + (Regex::new(r"break\b").unwrap(), |_| Ok(Token::Break)), + (Regex::new(r"continue\b").unwrap(), |_| Ok(Token::Continue)), + (Regex::new(r"else\b").unwrap(), |_| Ok(Token::Else)), + (Regex::new(r"goto\b").unwrap(), |_| Ok(Token::Goto)), + (Regex::new(r"switch\b").unwrap(), |_| Ok(Token::Switch)), + (Regex::new(r"case\b").unwrap(), |_| Ok(Token::Case)), + (Regex::new(r"default\b").unwrap(), |_| Ok(Token::Default)), + (Regex::new(r"[0-9]+[lL]\b").unwrap(), |s| { + let cons = s[..s.len() - 1].parse::<i64>()?; + Ok(Token::LongConstant(cons)) + }), (Regex::new(r"[0-9]+\b").unwrap(), |s| { - Token::Constant(s.parse::<i32>().unwrap()) + let cons = s.parse::<i64>()?; + Ok(Token::IntConstant(cons)) }), (Regex::new(r"[a-zA-Z_]\w*\b").unwrap(), |s| { - Token::Identifier(s.to_string()) + Ok(Token::Identifier(s.to_string())) }), ] }@@ -140,8 +148,10 @@ impl fmt::Display for Token {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Token::Identifier(x) => write!(f, "{}", x), - Token::Constant(x) => write!(f, "{}", x), + Token::IntConstant(x) => write!(f, "{}", x), + Token::LongConstant(x) => write!(f, "{}", x), Token::Int => write!(f, "int"), + Token::Long => write!(f, "long"), Token::Void => write!(f, "void"), Token::Return => write!(f, "return"), Token::OpenParanthesis => write!(f, "("),@@ -251,6 +261,45 @@ _ => false,
} } +pub fn is_assignment(tok: &Token) -> bool { + match tok { + Token::Assignment + | Token::CAddition + | Token::CNegation + | Token::CMultiplication + | Token::CDivision + | Token::CModulo + | Token::CAnd + | Token::COr + | Token::CXor + | Token::CLShift + | Token::CRShift => true, + _ => false, + } +} + +pub fn is_specifier(token: &Token) -> bool { + match token { + Token::Extern | Token::Static => true, + n if is_type_specifier(n) => true, + _ => false, + } +} + +pub fn is_type_specifier(tok: &Token) -> bool { + match tok { + Token::Int | Token::Long => true, + _ => false, + } +} + +pub fn is_constant(tok: &Token) -> bool { + match tok { + Token::IntConstant(_) | Token::LongConstant(_) => true, + _ => false, + } +} + pub fn precedence(token: &Token) -> usize { match token { Token::Multiplication | Token::Division | Token::Modulo => 50,@@ -287,7 +336,7 @@ 'outer: while !remaining.is_empty() {
for (re, tok) in &patterns { if let Some(m) = re.find(remaining) { if !m.is_empty() && m.start() == 0 { - tokens.push_back(tok(m.as_str())); + tokens.push_back(tok(m.as_str())?); remaining = remaining[m.end()..].trim_start(); continue 'outer; }
@@ -3,8 +3,8 @@ collections::HashMap,
sync::atomic::{AtomicUsize, Ordering}, }; -use crate::frontend::parse::{ - Ast, Block, BlockItem, Declaration, Expression, FunctionDeclaration, Statement, +use crate::frontend::ast::{ + Ast, Block, BlockItem, Const, Declaration, Expression, FunctionDeclaration, Statement, }; use anyhow::{Result, bail};@@ -18,7 +18,7 @@ statement: Statement,
hash_map: &mut HashMap<String, String>, current_loop: Option<String>, current_switch: Option<String>, - collected_cases: &mut Option<Vec<(Option<i32>, String)>>, + collected_cases: &mut Option<Vec<(Option<Const>, String)>>, ) -> Result<Statement> { match statement { Statement::Labeled(id, statement) => Ok(Statement::Labeled(@@ -131,18 +131,18 @@ let Some(unwrapped_cases) = collected_cases.as_mut() else {
bail!("Case not inside a switch statement"); }; match expression { - Expression::Constant(c) => { + Expression::Constant(val, _) => { if unwrapped_cases.iter().any(|(constant, _)| { if let Some(con) = constant { - return *con == c; + return *con == val; } false }) { - bail!("Case {} is duplicated", c); + bail!("Case {:?} is duplicated", val); } - unwrapped_cases.push((Some(c), n_label.clone())); + unwrapped_cases.push((Some(val.clone()), n_label.clone())); Ok(Statement::Case( - expression, + Expression::Constant(val, None), Box::new(resolve_loop_statement( *statement, hash_map,@@ -158,7 +158,7 @@ }
} Statement::Switch(expression, statement, _, _) => { let n_label = Some(gen_label("switch".to_string())); - let mut new_cases: Option<Vec<(Option<i32>, String)>> = Some(Vec::new()); + let mut new_cases: Option<Vec<(Option<Const>, String)>> = Some(Vec::new()); Ok(Statement::Switch( expression, Box::new(resolve_loop_statement(@@ -204,7 +204,7 @@ block: Block,
hash_map: &mut HashMap<String, String>, current_loop: Option<String>, current_switch: Option<String>, - collected_cases: &mut Option<Vec<(Option<i32>, String)>>, + collected_cases: &mut Option<Vec<(Option<Const>, String)>>, ) -> Result<Block> { match block { Block::B(block_items) => {@@ -232,12 +232,13 @@ decl: Declaration,
hash_map: &mut HashMap<String, String>, ) -> Result<Declaration> { match decl { - Declaration::F(FunctionDeclaration::D(name, args, block, storage_class)) => { + Declaration::F(FunctionDeclaration::D(name, args, block, var_type, storage_class)) => { let Some(bl) = block else { return Ok(Declaration::F(FunctionDeclaration::D( name, args, block, + var_type, storage_class, ))); };@@ -246,6 +247,7 @@ Ok(Declaration::F(FunctionDeclaration::D(
name, args, block, + var_type, storage_class, ))) }
@@ -1,131 +1,10 @@
use crate::frontend::{ + ast::*, lex::{self, Token}, - type_check::Type, }; use anyhow::{Result, bail}; use std::collections::VecDeque; -#[derive(Debug, PartialEq)] -pub enum Ast { - Program(Vec<Declaration>), -} - -#[derive(Debug, PartialEq)] -pub enum Declaration { - V(VariableDeclaration), - F(FunctionDeclaration), -} - -#[derive(Debug, PartialEq)] -pub enum Block { - B(Vec<BlockItem>), -} - -#[derive(Debug, PartialEq)] -pub enum BlockItem { - S(Statement), - D(Declaration), -} - -#[derive(Debug, PartialEq)] -pub enum VariableDeclaration { - D(String, Option<Expression>, Option<StorageClass>), -} - -#[derive(Debug, PartialEq)] -pub enum FunctionDeclaration { - D(String, Vec<String>, Option<Block>, Option<StorageClass>), -} - -#[derive(Debug, PartialEq)] -pub enum StorageClass { - Static, - Extern, -} - -#[derive(Debug, PartialEq)] -pub enum Statement { - Return(Expression), - Expression(Expression), - If(Expression, Box<Statement>, Option<Box<Statement>>), - While(Expression, Box<Statement>, Option<String>), - DoWhile(Box<Statement>, Expression, Option<String>), - For( - ForInit, - Option<Expression>, - Option<Expression>, - Box<Statement>, - Option<String>, - ), - Switch( - Expression, - Box<Statement>, - Option<String>, - Vec<(Option<i32>, String)>, - ), - Default(Box<Statement>, Option<String>), - Case(Expression, Box<Statement>, Option<String>), - Continue(Option<String>), - Break(Option<String>), - Compound(Block), - Labeled(String, Box<Statement>), - Goto(String), - Null, -} - -#[derive(Debug, PartialEq)] -pub enum ForInit { - D(VariableDeclaration), - E(Option<Expression>), -} - -#[derive(Debug, PartialEq)] -pub enum Expression { - Constant(i32), - Variable(String), - Unary(UnaryOp, Box<Expression>), - Binary(BinaryOp, Box<Expression>, Box<Expression>), - Assignment(Box<Expression>, Box<Expression>), - CompoundAssignment(BinaryOp, Box<Expression>, Box<Expression>), - PostIncr(Box<Expression>), - PostDecr(Box<Expression>), - Conditional(Box<Expression>, Box<Expression>, Box<Expression>), - FunctionCall(String, Vec<Box<Expression>>), -} - -#[derive(Debug, PartialEq)] -pub enum UnaryOp { - Complement, - Negation, - Not, - Increment, - Decrement, -} - -#[derive(Debug, PartialEq)] -pub enum BinaryOp { - Addition, - Subtraction, - Multiplication, - Division, - Modulo, - And, - Or, - Xor, - LShift, - RShift, - LAnd, - LOr, - Equal, - NEqual, - Less, - Greater, - LessEq, - GreaterEq, - - Assignment, -} - fn get_compound_operator(tok: Option<Token>) -> Result<Option<BinaryOp>> { match tok { Some(Token::Assignment) => Ok(None),@@ -144,30 +23,14 @@ None => bail!("End of file"),
} } -fn is_assignment(tok: &Token) -> bool { - match tok { - Token::Assignment - | Token::CAddition - | Token::CNegation - | Token::CMultiplication - | Token::CDivision - | Token::CModulo - | Token::CAnd - | Token::COr - | Token::CXor - | Token::CLShift - | Token::CRShift => true, - _ => false, - } -} - fn expect_token(expect: Token, actual: Option<Token>) -> Result<()> { match actual { Some(x) => { if expect != x { match x { Token::Identifier(val) => bail!("Expected {} but found {}", expect, val), - Token::Constant(val) => bail!("Expected {} but found {}", expect, val), + Token::IntConstant(val) => bail!("Expected {} but found {}", expect, val), + Token::LongConstant(val) => bail!("Expected {} but found {}", expect, val), val => bail!("Expected {} but found {}", expect, val), } }@@ -177,10 +40,17 @@ None => bail!("Expected {}", expect),
} } -/* constant ::= An i32 */ -fn parse_constant(tokens: &mut VecDeque<Token>) -> Result<i32> { +/* constant ::= An int or a long */ +fn parse_constant(tokens: &mut VecDeque<Token>) -> Result<Const> { match tokens.pop_front() { - Some(Token::Constant(x)) => Ok(x), + Some(Token::IntConstant(x)) => { + if x <= 2_i64.pow(31) - 1 { + Ok(Const::Int(x as i32)) + } else { + Ok(Const::Long(x)) + } + } + Some(Token::LongConstant(x)) => Ok(Const::Long(x)), Some(x) => bail!("Expected a constant but found {}", x), None => bail!("Expected a constant but file ended"), }@@ -195,6 +65,25 @@ None => bail!("Expected an identifier but file ended"),
} } +/* type_specifier ::= "int" | "long" */ +fn parse_type_specifier(tokens: &mut VecDeque<Token>) -> Result<Type> { + match tokens.pop_front() { + Some(Token::Int) => Ok(Type::Int), + Some(Token::Long) => Ok(Type::Long), + Some(x) => bail!("Expected type specifier but got {}", x), + None => bail!("Expected a type specifier but file ended"), + } +} + +/* type_specifier_list ::= { type_specifier }+ */ +fn parse_type_specifier_list(tokens: &mut VecDeque<Token>) -> Result<Type> { + let mut types = Vec::new(); + while lex::is_type_specifier(&tokens[0]) { + types.push(parse_type_specifier(tokens)?); + } + parse_type_helper(types) +} + /* unop ::= "~" | "-" | ... | "--" */ fn parse_unop(tokens: &mut VecDeque<Token>) -> Result<UnaryOp> { match tokens.pop_front() {@@ -238,9 +127,9 @@
/* first_expr ::= <constant> | <identifier> | "(" <exp> ")" */ fn parse_first_expr(tokens: &mut VecDeque<Token>) -> Result<Expression> { match &tokens[0] { - Token::Constant(_) => { + x if lex::is_constant(x) => { let constant = parse_constant(tokens)?; - Ok(Expression::Constant(constant)) + Ok(Expression::Constant(constant, None)) } Token::OpenParanthesis => { expect_token(Token::OpenParanthesis, tokens.pop_front())?;@@ -250,7 +139,7 @@ Ok(expr)
} Token::Identifier(_) => { let id = parse_identifier(tokens)?; - Ok(Expression::Variable(id)) + Ok(Expression::Variable(id, None)) } x => bail!("Broken expression: got {}", x), }@@ -261,18 +150,19 @@ fn parse_postfixes(tokens: &mut VecDeque<Token>, expr: Expression) -> Result<Expression> {
match &tokens[0] { Token::Increment => { expect_token(Token::Increment, tokens.pop_front())?; - parse_postfixes(tokens, Expression::PostIncr(Box::new(expr))) + parse_postfixes(tokens, Expression::PostIncr(Box::new(expr), None)) } Token::Decrement => { expect_token(Token::Decrement, tokens.pop_front())?; - parse_postfixes(tokens, Expression::PostDecr(Box::new(expr))) + parse_postfixes(tokens, Expression::PostDecr(Box::new(expr), None)) } _ => Ok(expr), } } /* factor ::= <unop> <factor> | <first_expr> <postfixes> - * | <identifier> "(" [ <exp> { "," <exp> } ] ")" */ + * | <identifier> "(" [ <exp> { "," <exp> } ] ")" + * | "(" <type_specifier> ")" <factor> */ fn parse_factor(tokens: &mut VecDeque<Token>) -> Result<Expression> { match (&tokens[0], &tokens[1]) { (Token::Complement, _)@@ -282,7 +172,7 @@ | (Token::Increment, _)
| (Token::Decrement, _) => { let unop = parse_unop(tokens)?; let expr = parse_factor(tokens)?; - Ok(Expression::Unary(unop, Box::new(expr))) + Ok(Expression::Unary(unop, Box::new(expr), None)) } (Token::Identifier(_), Token::OpenParanthesis) => { let id = parse_identifier(tokens)?;@@ -296,7 +186,14 @@ args.push(Box::new(parse_expression(tokens, 0)?));
} expect_token(Token::CloseParanthesis, tokens.pop_front())?; - Ok(Expression::FunctionCall(id, args)) + Ok(Expression::FunctionCall(id, args, None)) + } + (Token::OpenParanthesis, n) if lex::is_type_specifier(n) => { + expect_token(Token::OpenParanthesis, tokens.pop_front())?; + let type_spec = parse_type_specifier(tokens)?; + expect_token(Token::CloseParanthesis, tokens.pop_front())?; + let factor = parse_factor(tokens)?; + Ok(Expression::Cast(type_spec, Box::new(factor), None)) } _ => { let expr = parse_first_expr(tokens)?;@@ -311,23 +208,23 @@ fn parse_expression(tokens: &mut VecDeque<Token>, order: usize) -> Result<Expression> {
let mut left = parse_factor(tokens)?; while lex::is_binary(&tokens[0]) && lex::precedence(&tokens[0]) >= order { let prec = lex::precedence(&tokens[0]); - if is_assignment(&tokens[0]) { + if lex::is_assignment(&tokens[0]) { let compound_op = tokens.pop_front(); let right = parse_expression(tokens, prec)?; left = match get_compound_operator(compound_op)? { - None => Expression::Assignment(Box::new(left), Box::new(right)), - Some(x) => Expression::CompoundAssignment(x, Box::new(left), Box::new(right)), + None => Expression::Assignment(Box::new(left), Box::new(right), None), + Some(x) => Expression::CompoundAssignment(x, Box::new(left), Box::new(right), None), }; } else if matches!(tokens[0], Token::QuestionMark) { expect_token(Token::QuestionMark, tokens.pop_front())?; let middle = parse_expression(tokens, 0)?; expect_token(Token::Colon, tokens.pop_front())?; let right = parse_expression(tokens, prec)?; - left = Expression::Conditional(Box::new(left), Box::new(middle), Box::new(right)); + left = Expression::Conditional(Box::new(left), Box::new(middle), Box::new(right), None); } else { let op = parse_binop(tokens)?; let right = parse_expression(tokens, prec + 1)?; - left = Expression::Binary(op, Box::new(left), Box::new(right)); + left = Expression::Binary(op, Box::new(left), Box::new(right), None); } } Ok(left)@@ -335,16 +232,16 @@ }
/* for_init ::= <declaration> | [ <expression> ] ";" */ fn parse_for_init(tokens: &mut VecDeque<Token>) -> Result<ForInit> { - match &tokens[0] { - Token::Int | Token::Extern | Token::Static => { - let declaration = parse_declaration(tokens)?; - match declaration { - Declaration::V(variable_declaration) => Ok(ForInit::D(variable_declaration)), - Declaration::F(_) => { - bail!("Function declaration is not allowed in initializer of for loop") - } + if lex::is_specifier(&tokens[0]) { + let declaration = parse_declaration(tokens)?; + match declaration { + Declaration::V(variable_declaration) => return Ok(ForInit::D(variable_declaration)), + Declaration::F(_) => { + bail!("Function declaration is not allowed in initializer of for loop"); } } + } + match &tokens[0] { Token::Semicolon => { expect_token(Token::Semicolon, tokens.pop_front())?; Ok(ForInit::E(None))@@ -498,9 +395,10 @@ }
/* block-item ::= <statement> | <declaration> */ fn parse_block_item(tokens: &mut VecDeque<Token>) -> Result<BlockItem> { - match &tokens[0] { - Token::Int | Token::Extern | Token::Static => Ok(BlockItem::D(parse_declaration(tokens)?)), - _ => Ok(BlockItem::S(parse_statement(tokens)?)), + if lex::is_specifier(&tokens[0]) { + Ok(BlockItem::D(parse_declaration(tokens)?)) + } else { + Ok(BlockItem::S(parse_statement(tokens)?)) } }@@ -515,73 +413,79 @@ expect_token(Token::CloseBrace, tokens.pop_front())?;
Ok(Block::B(body)) } -/* param-list ::= "void" | "int" <identifier> { "," "int" <identifier> } */ -fn parse_param_list(tokens: &mut VecDeque<Token>) -> Result<Vec<String>> { +/* param-list ::= "void" | <type_specifier> <identifier> { "," <type_specifier> <identifier> } */ +fn parse_param_list(tokens: &mut VecDeque<Token>) -> Result<(Vec<Type>, Vec<String>)> { match &tokens[0] { Token::Void => { expect_token(Token::Void, tokens.pop_front())?; - Ok(Vec::new()) + Ok((Vec::new(), Vec::new())) } - Token::Int => { - expect_token(Token::Int, tokens.pop_front())?; - let mut result = Vec::new(); - result.push(parse_identifier(tokens)?); + n if lex::is_type_specifier(n) => { + let mut types = Vec::new(); + let mut ids = Vec::new(); + types.push(parse_type_specifier_list(tokens)?); + ids.push(parse_identifier(tokens)?); while tokens[0] == Token::Comma { expect_token(Token::Comma, tokens.pop_front())?; - expect_token(Token::Int, tokens.pop_front())?; - result.push(parse_identifier(tokens)?); + types.push(parse_type_specifier_list(tokens)?); + ids.push(parse_identifier(tokens)?); } - Ok(result) + Ok((types, ids)) } x => bail!("Wrong token in parameter list: {}", x), } } -fn is_specifier(token: &Token) -> bool { - match token { - Token::Int | Token::Extern | Token::Static => true, - _ => false, +/* Helper for extracting the correct type from different combinations */ +fn parse_type_helper(types: Vec<Type>) -> Result<Type> { + match types[..] { + [Type::Int] => Ok(Type::Int), + [Type::Long] | [Type::Int, Type::Long] | [Type::Long, Type::Int] => Ok(Type::Long), + _ => bail!("Invalid type specifier {:?}", types), + } +} + +/* storage_class_specifier ::= "static" | "extern" */ +fn parse_storage_class_specifier(tokens: &mut VecDeque<Token>) -> Result<StorageClass> { + match tokens.pop_front() { + Some(Token::Extern) => Ok(StorageClass::Extern), + Some(Token::Static) => Ok(StorageClass::Static), + Some(x) => bail!("Expected a storage class specifier but found {}", x), + None => bail!("Expected a storage class specifier but file ended"), } } -/* specifier ::= { "int" | "static" | "extern" }+ */ -fn parse_specifier(tokens: &mut VecDeque<Token>) -> Result<(Type, Option<StorageClass>)> { +/* specifier_list ::= { <type_specifier> | <storage_class_specifier> }+ */ +fn parse_specifier_list(tokens: &mut VecDeque<Token>) -> Result<(Type, Option<StorageClass>)> { let mut types: Vec<Type> = Vec::new(); let mut storage_classes: Vec<StorageClass> = Vec::new(); - while is_specifier(&tokens[0]) { - match tokens.pop_front() { - Some(Token::Int) => types.push(Type::Int), - Some(Token::Extern) => storage_classes.push(StorageClass::Extern), - Some(Token::Static) => storage_classes.push(StorageClass::Static), - x => bail!("Expected specifier, got {:?}", x), + while lex::is_specifier(&tokens[0]) { + if lex::is_type_specifier(&tokens[0]) { + types.push(parse_type_specifier(tokens)?); + } else { + storage_classes.push(parse_storage_class_specifier(tokens)?); } } - if types.len() != 1 { - bail!("Invalid type specifier: {:?}", types); - } - if storage_classes.len() > 1 { bail!("Invalid storage class specifier: {:?}", storage_classes); } - let Some(type_specifier) = types.pop() else { - bail!("Invalid type specifier: {:?}", types); - }; + let type_specifier = parse_type_helper(types)?; Ok((type_specifier, storage_classes.pop())) } /* declaration ::= <specifier> ( <variable_declaration> | <function_declaration> ) */ -/* function_declaration ::= "int" <identifier> "(" <param-list> ")" ( <block> | ";" ) */ -/* variable_declaration ::= "int" <identifier> [ "=" <expression> ] ";" */ +/* function_declaration ::= <identifier> "(" <param-list> ")" ( <block> | ";" ) */ +/* variable_declaration ::= <identifier> [ "=" <expression> ] ";" */ fn parse_declaration(tokens: &mut VecDeque<Token>) -> Result<Declaration> { - let specifier = parse_specifier(tokens)?; + let specifier = parse_specifier_list(tokens)?; let id = parse_identifier(tokens)?; match &tokens[0] { Token::OpenParanthesis => { expect_token(Token::OpenParanthesis, tokens.pop_front())?; - let param_list = parse_param_list(tokens)?; + let (param_types, param_ids) = parse_param_list(tokens)?; expect_token(Token::CloseParanthesis, tokens.pop_front())?; let block = match &tokens[0] { Token::Semicolon => {@@ -590,10 +494,15 @@ None
} _ => Some(parse_block(tokens)?), }; + let func_type = Type::Function( + param_types.into_iter().map(Box::new).collect(), + Box::new(specifier.0), + ); Ok(Declaration::F(FunctionDeclaration::D( id, - param_list, + param_ids, block, + func_type, specifier.1, ))) }@@ -604,6 +513,7 @@ expect_token(Token::Semicolon, tokens.pop_front())?;
Ok(Declaration::V(VariableDeclaration::D( id, Some(expression), + specifier.0, specifier.1, ))) }@@ -612,6 +522,7 @@ expect_token(Token::Semicolon, tokens.pop_front())?;
Ok(Declaration::V(VariableDeclaration::D( id, None, + specifier.0, specifier.1, ))) }@@ -622,7 +533,7 @@
/* program ::= { <function-declaration> } */ pub fn parse_tokens(mut tokens: VecDeque<Token>) -> Result<Ast> { let mut result = Vec::new(); - while tokens.len() > 0 && is_specifier(&tokens[0]) { + while tokens.len() > 0 && lex::is_specifier(&tokens[0]) { result.push(parse_declaration(&mut tokens)?); }
@@ -1,20 +1,15 @@
use std::{ collections::HashMap, + iter::zip, sync::atomic::{AtomicUsize, Ordering}, }; -use crate::frontend::parse::{ - Ast, Block, BlockItem, Declaration, Expression, ForInit, FunctionDeclaration, Statement, - StorageClass, VariableDeclaration, +use crate::frontend::ast::{ + Ast, BinaryOp, Block, BlockItem, Const, Declaration, Expression, ForInit, FunctionDeclaration, + Statement, StorageClass, Type, UnaryOp, VariableDeclaration, }; use anyhow::{Result, bail}; -#[derive(Debug, PartialEq)] -pub enum Type { - Int, - Function(i32 /* args count */), -} - #[derive(Debug, PartialEq, Clone)] pub enum IdentifierAttributes { FunctionAttributes(bool, bool /* defined, global */),@@ -25,75 +20,305 @@
#[derive(Debug, PartialEq, Clone)] pub enum InitialValue { Tentative, - Initial(i32), + Initial(StaticInit), NoInit, } +#[derive(Debug, PartialEq, Clone)] +pub enum StaticInit { + IntInit(i32), + LongInit(i64), +} + +pub fn get_expression_type(expr: &Expression) -> Result<&Type> { + match expr { + Expression::Constant(_, Some(var_type)) => Ok(var_type), + Expression::Variable(_, Some(var_type)) => Ok(var_type), + Expression::Cast(_, _, Some(var_type)) => Ok(var_type), + Expression::Unary(_, _, Some(var_type)) => Ok(var_type), + Expression::Binary(_, _, _, Some(var_type)) => Ok(var_type), + Expression::Assignment(_, _, Some(var_type)) => Ok(var_type), + Expression::CompoundAssignment(_, _, _, Some(var_type)) => Ok(var_type), + Expression::PostIncr(_, Some(var_type)) => Ok(var_type), + Expression::PostDecr(_, Some(var_type)) => Ok(var_type), + Expression::Conditional(_, _, _, Some(var_type)) => Ok(var_type), + Expression::FunctionCall(_, _, Some(var_type)) => Ok(var_type), + _ => bail!("Expressions should have a type at this point"), + } +} + +fn set_expression_type(expr: Expression, new_type: Type) -> Result<Expression> { + match expr { + Expression::Constant(a, _) => Ok(Expression::Constant(a, Some(new_type))), + Expression::Variable(a, _) => Ok(Expression::Variable(a, Some(new_type))), + Expression::Cast(a, b, _) => Ok(Expression::Cast(a, b, Some(new_type))), + Expression::Unary(a, b, _) => Ok(Expression::Unary(a, b, Some(new_type))), + Expression::Binary(a, b, c, _) => Ok(Expression::Binary(a, b, c, Some(new_type))), + Expression::Assignment(a, b, _) => Ok(Expression::Assignment(a, b, Some(new_type))), + Expression::CompoundAssignment(a, b, c, _) => { + Ok(Expression::CompoundAssignment(a, b, c, Some(new_type))) + } + Expression::PostIncr(a, _) => Ok(Expression::PostIncr(a, Some(new_type))), + Expression::PostDecr(a, _) => Ok(Expression::PostDecr(a, Some(new_type))), + Expression::Conditional(a, b, c, _) => Ok(Expression::Conditional(a, b, c, Some(new_type))), + Expression::FunctionCall(a, b, _) => Ok(Expression::FunctionCall(a, b, Some(new_type))), + } +} + +fn get_common_type(type1: Type, type2: Type) -> Result<Type> { + if type1 == type2 { + return Ok(type1); + } else { + return Ok(Type::Long); + } +} + +fn convert_const_to_static_init(constant: &Const) -> Result<StaticInit> { + match constant { + Const::Int(x) => Ok(StaticInit::IntInit(*x)), + Const::Long(x) => Ok(StaticInit::LongInit(*x)), + } +} + +fn convert_const_to_type(constant: &Const) -> Result<Type> { + match constant { + Const::Int(_) => Ok(Type::Int), + Const::Long(_) => Ok(Type::Long), + } +} + +fn set_const_type(constant: Const, new_type: Type) -> Result<Const> { + match constant { + Const::Int(x) => match new_type { + Type::Int => Ok(constant), + Type::Long => Ok(Const::Long(x as i64)), + Type::Function(_, _) => bail!("Const to function is not allowed"), + }, + Const::Long(x) => match new_type { + Type::Int => { + let mut val = x; + while val > 2_i64.pow(31) - 1 { + val -= 2_i64.pow(32); + } + Ok(Const::Int(val as i32)) + } + Type::Long => Ok(constant), + Type::Function(_, _) => bail!("Const to function is not allowed"), + }, + } +} + +fn set_init_value_type(init: StaticInit, new_type: Type) -> Result<StaticInit> { + match init { + StaticInit::IntInit(x) => match new_type { + Type::Int => Ok(init), + Type::Long => Ok(StaticInit::LongInit(x as i64)), + Type::Function(_, _) => bail!("StaticInit to function is not allowed"), + }, + StaticInit::LongInit(x) => match new_type { + Type::Int => { + let mut val = x; + while val > 2_i64.pow(31) - 1 { + val -= 2_i64.pow(32); + } + Ok(StaticInit::IntInit(val as i32)) + } + Type::Long => Ok(init), + Type::Function(_, _) => bail!("StaticInit to function is not allowed"), + }, + } +} + fn typecheck_expression( expr: Expression, - hash_map: &mut HashMap<String, (Type, IdentifierAttributes)>, + hash_map: &HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<Expression> { match expr { - Expression::Variable(x) => { - if Type::Int != hash_map.get(&x).unwrap().0 { - bail!("Function name used as a variable {}", x); + Expression::Variable(x, _) => { + if let Some(var) = hash_map.get(&x) { + if matches!(var.0, Type::Function(_, _)) { + bail!("Function name used as a variable {}", x); + } + return Ok(Expression::Variable(x, Some(var.0.clone()))); + } + bail!("Variable {} not found in symbol table", x); + } + Expression::Constant(x, _) => match x { + Const::Int(i) => Ok(Expression::Constant(Const::Int(i), Some(Type::Int))), + Const::Long(l) => Ok(Expression::Constant(Const::Long(l), Some(Type::Long))), + }, + Expression::Unary(unary_op, expression, _) => { + let expr = typecheck_expression(*expression, hash_map)?; + let expr_type = get_expression_type(&expr)?.clone(); + match unary_op { + UnaryOp::Not => Ok(Expression::Unary(unary_op, Box::new(expr), Some(Type::Int))), + _ => Ok(Expression::Unary(unary_op, Box::new(expr), Some(expr_type))), } - Ok(Expression::Variable(x)) } - Expression::Unary(unary_op, expression) => Ok(Expression::Unary( - unary_op, - Box::new(typecheck_expression(*expression, hash_map)?), - )), - Expression::Binary(binary_op, left, right) => Ok(Expression::Binary( - binary_op, - Box::new(typecheck_expression(*left, hash_map)?), - Box::new(typecheck_expression(*right, hash_map)?), - )), - Expression::Assignment(left, right) => Ok(Expression::Assignment( - Box::new(typecheck_expression(*left, hash_map)?), - Box::new(typecheck_expression(*right, hash_map)?), - )), - Expression::CompoundAssignment(binary_op, left, right) => { + Expression::Binary(binary_op, left, right, _) => { + let left = typecheck_expression(*left, hash_map)?; + let right = typecheck_expression(*right, hash_map)?; + if matches!(binary_op, BinaryOp::LAnd) || matches!(binary_op, BinaryOp::LOr) { + return Ok(Expression::Binary( + binary_op, + Box::new(left), + Box::new(right), + Some(Type::Int), + )); + } + if matches!(binary_op, BinaryOp::LShift) || matches!(binary_op, BinaryOp::RShift) { + let left_type = get_expression_type(&left)?.clone(); + return Ok(Expression::Binary( + binary_op, + Box::new(left), + Box::new(right), + Some(left_type), + )); + } + let common_type = get_common_type( + get_expression_type(&left)?.clone(), + get_expression_type(&right)?.clone(), + )?; + let left = Expression::Cast( + common_type.clone(), + Box::new(left), + Some(common_type.clone()), + ); + let right = Expression::Cast( + common_type.clone(), + Box::new(right), + Some(common_type.clone()), + ); + match binary_op { + BinaryOp::Addition + | BinaryOp::Subtraction + | BinaryOp::Multiplication + | BinaryOp::Division + | BinaryOp::Modulo + | BinaryOp::And + | BinaryOp::Or + | BinaryOp::Xor => Ok(Expression::Binary( + binary_op, + Box::new(left), + Box::new(right), + Some(common_type), + )), + _ => Ok(Expression::Binary( + binary_op, + Box::new(left), + Box::new(right), + Some(Type::Int), + )), + } + } + Expression::Assignment(left, right, _) => { + let left = typecheck_expression(*left, hash_map)?; + let right = typecheck_expression(*right, hash_map)?; + let left_type = get_expression_type(&left)?.clone(); + let right = + Expression::Cast(left_type.clone(), Box::new(right), Some(left_type.clone())); + Ok(Expression::Assignment( + Box::new(left), + Box::new(right), + Some(left_type), + )) + } + Expression::CompoundAssignment(binary_op, left, right, _) => { + let left = typecheck_expression(*left, hash_map)?; + let right = typecheck_expression(*right, hash_map)?; + let left_type = get_expression_type(&left)?.clone(); + + if binary_op == BinaryOp::LShift || binary_op == BinaryOp::RShift { + return Ok(Expression::CompoundAssignment( + binary_op, + Box::new(left), + Box::new(right), + Some(left_type), + )); + } + + let common_type = get_common_type( + get_expression_type(&left)?.clone(), + get_expression_type(&right)?.clone(), + )?; + let right = Expression::Cast( + common_type.clone(), + Box::new(right), + Some(common_type.clone()), + ); Ok(Expression::CompoundAssignment( binary_op, - Box::new(typecheck_expression(*left, hash_map)?), - Box::new(typecheck_expression(*right, hash_map)?), + Box::new(left), + Box::new(right), + Some(common_type), )) } - Expression::PostIncr(expr) => Ok(Expression::PostIncr(Box::new(typecheck_expression( - *expr, hash_map, - )?))), - Expression::PostDecr(expr) => Ok(Expression::PostDecr(Box::new(typecheck_expression( - *expr, hash_map, - )?))), - Expression::Conditional(left, middle, right) => Ok(Expression::Conditional( - Box::new(typecheck_expression(*left, hash_map)?), - Box::new(typecheck_expression(*middle, hash_map)?), - Box::new(typecheck_expression(*right, hash_map)?), - )), - Expression::FunctionCall(name, args) => match hash_map.get(&name).unwrap().0 { - Type::Int => bail!("Variable used as a function name"), - Type::Function(param_count) if param_count as usize == args.len() => { + Expression::PostIncr(expr, _) => { + let expr = typecheck_expression(*expr, hash_map)?; + let expr_type = get_expression_type(&expr)?.clone(); + Ok(Expression::PostIncr(Box::new(expr), Some(expr_type))) + } + Expression::PostDecr(expr, _) => { + let expr = typecheck_expression(*expr, hash_map)?; + let expr_type = get_expression_type(&expr)?.clone(); + Ok(Expression::PostDecr(Box::new(expr), Some(expr_type))) + } + + Expression::Conditional(cond, then, else_expr, _) => { + let cond = typecheck_expression(*cond, hash_map)?; + let then = typecheck_expression(*then, hash_map)?; + let else_expr = typecheck_expression(*else_expr, hash_map)?; + let common_type = get_common_type( + get_expression_type(&then)?.clone(), + get_expression_type(&else_expr)?.clone(), + )?; + let then = set_expression_type(then, common_type.clone())?; + let else_expr = set_expression_type(else_expr, common_type.clone())?; + Ok(Expression::Conditional( + Box::new(cond), + Box::new(then), + Box::new(else_expr), + Some(common_type), + )) + } + Expression::FunctionCall(name, args, _) => match hash_map.get(&name) { + Some((Type::Function(param_types, ret_type), _)) => { + if param_types.len() != args.len() { + bail!( + "Function {} called with {} arguments, but it has {}", + name, + args.len(), + param_types.len(), + ); + } let mut new_args = Vec::new(); - for arg in args { - new_args.push(Box::new(typecheck_expression(*arg, hash_map)?)); + for (arg, param_type) in zip(args, param_types) { + let typed_arg = Expression::Cast( + *param_type.clone(), + Box::new(typecheck_expression(*arg, hash_map)?), + Some(*param_type.clone()), + ); + new_args.push(Box::new(typed_arg)); } - Ok(Expression::FunctionCall(name, new_args)) + Ok(Expression::FunctionCall( + name, + new_args, + Some(*ret_type.clone()), + )) } - Type::Function(param_count) => bail!( - "Function {} called with {} arguments, but it has {}", - name, - args.len(), - param_count, - ), + _ => bail!("Variable used as a function name"), }, - c => Ok(c), + Expression::Cast(new_type, expr, _) => Ok(Expression::Cast( + new_type.clone(), + Box::new(typecheck_expression(*expr, hash_map)?), + Some(new_type), + )), } } fn typecheck_optional_expression( opt_expression: Option<Expression>, - hash_map: &mut HashMap<String, (Type, IdentifierAttributes)>, + hash_map: &HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<Option<Expression>> { match opt_expression { Some(expression) => Ok(Some(typecheck_expression(expression, hash_map)?)),@@ -109,7 +334,7 @@ match for_init {
ForInit::D(declaration) => { let decl = typecheck_local_variable_declaration(declaration, hash_map)?; match &decl { - VariableDeclaration::D(_, _, storage_class) => { + VariableDeclaration::D(_, _, _, storage_class) => { if *storage_class != None { bail!("For loop headers cannot have a storage class"); }@@ -127,11 +352,25 @@
fn typecheck_statement( statement: Statement, hash_map: &mut HashMap<String, (Type, IdentifierAttributes)>, + function_name: &String, ) -> Result<Statement> { match statement { - Statement::Return(expression) => Ok(Statement::Return(typecheck_expression( - expression, hash_map, - )?)), + Statement::Return(expression) => { + let expr = typecheck_expression(expression, hash_map)?; + if let Some((fun_type, _)) = hash_map.get(function_name) { + match fun_type { + Type::Function(_, ret_type) => { + return Ok(Statement::Return(Expression::Cast( + *ret_type.clone(), + Box::new(expr), + Some(*ret_type.clone()), + ))); + } + _ => bail!("Should be function"), + } + } + Ok(Statement::Return(expr)) + } Statement::Expression(expression) => Ok(Statement::Expression(typecheck_expression( expression, hash_map, )?)),@@ -139,27 +378,31 @@ Statement::Null => Ok(Statement::Null),
Statement::If(condition, if_statement, else_statement) => match else_statement { Some(x) => Ok(Statement::If( typecheck_expression(condition, hash_map)?, - Box::new(typecheck_statement(*if_statement, hash_map)?), - Some(Box::new(typecheck_statement(*x, hash_map)?)), + Box::new(typecheck_statement(*if_statement, hash_map, function_name)?), + Some(Box::new(typecheck_statement(*x, hash_map, function_name)?)), )), None => Ok(Statement::If( typecheck_expression(condition, hash_map)?, - Box::new(typecheck_statement(*if_statement, hash_map)?), + Box::new(typecheck_statement(*if_statement, hash_map, function_name)?), None, )), }, Statement::Labeled(id, statement) => Ok(Statement::Labeled( id, - Box::new(typecheck_statement(*statement, hash_map)?), + Box::new(typecheck_statement(*statement, hash_map, function_name)?), )), - Statement::Compound(block) => Ok(Statement::Compound(typecheck_block(block, hash_map)?)), + Statement::Compound(block) => Ok(Statement::Compound(typecheck_block( + block, + hash_map, + function_name, + )?)), Statement::While(expression, statement, label) => Ok(Statement::While( typecheck_expression(expression, hash_map)?, - Box::new(typecheck_statement(*statement, hash_map)?), + Box::new(typecheck_statement(*statement, hash_map, function_name)?), label, )), Statement::DoWhile(statement, expression, label) => Ok(Statement::DoWhile( - Box::new(typecheck_statement(*statement, hash_map)?), + Box::new(typecheck_statement(*statement, hash_map, function_name)?), typecheck_expression(expression, hash_map)?, label, )),@@ -167,7 +410,7 @@ Statement::For(for_init, condition, step, body, label) => {
let for_init = typecheck_for_init(for_init, hash_map)?; let condition = typecheck_optional_expression(condition, hash_map)?; let step = typecheck_optional_expression(step, hash_map)?; - let body = typecheck_statement(*body, hash_map)?; + let body = typecheck_statement(*body, hash_map, function_name)?; Ok(Statement::For( for_init, condition,@@ -178,17 +421,38 @@ ))
} Statement::Case(expression, statement, label) => Ok(Statement::Case( expression, - Box::new(typecheck_statement(*statement, hash_map)?), + Box::new(typecheck_statement(*statement, hash_map, function_name)?), label, )), - Statement::Switch(expression, statement, label, cases) => Ok(Statement::Switch( - typecheck_expression(expression, hash_map)?, - Box::new(typecheck_statement(*statement, hash_map)?), - label, - cases, - )), + Statement::Switch(expression, statement, label, cases) => { + let expr = typecheck_expression(expression, hash_map)?; + let expr_type = get_expression_type(&expr)?; + + let mut new_cases = Vec::new(); + for (case, id) in &cases { + if let Some(constant) = case { + let new_case = set_const_type(constant.clone(), expr_type.clone())?; + for (test_case, _) in &new_cases { + if let Some(cc) = test_case { + if *cc == new_case { + bail!("Duplication of cases in switch statement: {:?}", new_cases); + } + } + } + new_cases.push((Some(new_case), id.clone())); + } else { + new_cases.push((None, id.clone())); + } + } + Ok(Statement::Switch( + expr, + Box::new(typecheck_statement(*statement, hash_map, function_name)?), + label, + new_cases, + )) + } Statement::Default(statement, label) => Ok(Statement::Default( - Box::new(typecheck_statement(*statement, hash_map)?), + Box::new(typecheck_statement(*statement, hash_map, function_name)?), label, )), c => Ok(c),@@ -200,9 +464,11 @@ decl: VariableDeclaration,
hash_map: &mut HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<VariableDeclaration> { match decl { - VariableDeclaration::D(name, initializer, storage_class) => { - let mut init_value = match initializer { - Some(Expression::Constant(x)) => InitialValue::Initial(x), + VariableDeclaration::D(name, initializer, var_type, storage_class) => { + let mut init_value = match &initializer { + Some(Expression::Constant(x, _)) => { + InitialValue::Initial(convert_const_to_static_init(x)?) + } Some(_) => bail!("Non-constant initializer"), None => { if matches!(storage_class, Some(StorageClass::Extern)) {@@ -215,8 +481,8 @@ };
let mut global = !matches!(storage_class, Some(StorageClass::Static)); if let Some(old_decl) = hash_map.get(&name) { - if !matches!(old_decl.0, Type::Int) { - bail!("Function redeclaration as variable is not allowed"); + if old_decl.0 != var_type { + bail!("Type mismatching redeclaration of variables is not allowed"); } match &old_decl.1 {@@ -247,9 +513,26 @@ _ => bail!("No static attributes found for global variable declaration"),
}; } + let init_value = match init_value { + InitialValue::Tentative => init_value, + InitialValue::Initial(static_init) => { + InitialValue::Initial(set_init_value_type(static_init, var_type.clone())?) + } + InitialValue::NoInit => init_value, + }; + let initializer = match initializer { + Some(e) => Some(set_expression_type(e, var_type.clone())?), + None => None, + }; + let attrs = IdentifierAttributes::StaticAttributes(init_value, global); - hash_map.insert(name.clone(), (Type::Int, attrs)); - Ok(VariableDeclaration::D(name, initializer, storage_class)) + hash_map.insert(name.clone(), (var_type.clone(), attrs)); + Ok(VariableDeclaration::D( + name, + initializer, + var_type, + storage_class, + )) } } }@@ -259,7 +542,7 @@ decl: VariableDeclaration,
hash_map: &mut HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<VariableDeclaration> { match decl { - VariableDeclaration::D(name, opt_expression, storage_class) => { + VariableDeclaration::D(name, opt_expression, var_type, storage_class) => { let mut new_init = opt_expression; match storage_class { Some(StorageClass::Extern) => {@@ -267,44 +550,60 @@ if new_init != None {
bail!("Initializer on local extern variable not allowed"); } if let Some(old_decl) = hash_map.get(&name) { - if !matches!(old_decl.0, Type::Int) { - bail!("Function redeclared as variable"); + if old_decl.0 != var_type { + bail!("Type mismatch in local variable declaration"); } } else { hash_map.insert( name.clone(), ( - Type::Int, + var_type.clone(), IdentifierAttributes::StaticAttributes(InitialValue::NoInit, true), ), ); } } Some(StorageClass::Static) => { - let init_value = match new_init { - Some(Expression::Constant(x)) => InitialValue::Initial(x), + let init = match &new_init { + Some(Expression::Constant(x, _)) => convert_const_to_static_init(x)?, Some(_) => { bail!("Non-constant initializer on local static variable not allowed") } - None => InitialValue::Initial(0), + None => StaticInit::IntInit(0), }; hash_map.insert( name.clone(), ( - Type::Int, - IdentifierAttributes::StaticAttributes(init_value, false), + var_type.clone(), + IdentifierAttributes::StaticAttributes( + InitialValue::Initial(set_init_value_type(init, var_type.clone())?), + false, + ), ), ); } None => { - hash_map.insert(name.clone(), (Type::Int, IdentifierAttributes::LocalAttr)); + hash_map.insert( + name.clone(), + (var_type.clone(), IdentifierAttributes::LocalAttr), + ); if let Some(init) = new_init { new_init = Some(typecheck_expression(init, hash_map)?); } } }; - Ok(VariableDeclaration::D(name, new_init, storage_class)) + let new_init = match new_init { + Some(e) => Some(set_expression_type(e, var_type.clone())?), + None => None, + }; + + Ok(VariableDeclaration::D( + name, + new_init, + var_type, + storage_class, + )) } } }@@ -314,14 +613,13 @@ decl: FunctionDeclaration,
hash_map: &mut HashMap<String, (Type, IdentifierAttributes)>, ) -> Result<FunctionDeclaration> { match decl { - FunctionDeclaration::D(name, params, block, storage_class) => { - let function_type = Type::Function(params.len() as i32); + FunctionDeclaration::D(name, params, block, fun_type, storage_class) => { let has_body = block != None; let mut defined = false; let mut global = !matches!(storage_class, Some(StorageClass::Static)); if let Some(old_decl) = hash_map.get(&name) { - if function_type != old_decl.0 { + if fun_type != old_decl.0 { bail!("Incompatible function declarations"); }@@ -341,20 +639,26 @@ };
} let attr = IdentifierAttributes::FunctionAttributes(has_body | defined, global); - hash_map.insert(name.clone(), (function_type, attr.clone())); + hash_map.insert(name.clone(), (fun_type.clone(), attr.clone())); let mut new_block = None; - if let Some(body) = block { - for param in ¶ms { - hash_map.insert(param.clone(), (Type::Int, attr.clone())); + match &fun_type { + Type::Function(items, _) => { + if let Some(body) = block { + for (param, item) in zip(¶ms, items) { + hash_map.insert(param.clone(), (*item.clone(), attr.clone())); + } + new_block = Some(typecheck_block(body, hash_map, &name)?); + } } - new_block = Some(typecheck_block(body, hash_map)?); - } + _ => bail!("Fun type should be function"), + }; Ok(FunctionDeclaration::D( name, params, new_block, + fun_type, storage_class, )) }@@ -394,9 +698,14 @@
fn typecheck_block_item( item: BlockItem, hash_map: &mut HashMap<String, (Type, IdentifierAttributes)>, + function_name: &String, ) -> Result<BlockItem> { match item { - BlockItem::S(statement) => Ok(BlockItem::S(typecheck_statement(statement, hash_map)?)), + BlockItem::S(statement) => Ok(BlockItem::S(typecheck_statement( + statement, + hash_map, + function_name, + )?)), BlockItem::D(declaration) => Ok(BlockItem::D(typecheck_local_declaration( declaration, hash_map,@@ -407,12 +716,13 @@
fn typecheck_block( block: Block, hash_map: &mut HashMap<String, (Type, IdentifierAttributes)>, + function_name: &String, ) -> Result<Block> { match block { Block::B(block_items) => { let mut new_block_items = Vec::new(); for item in block_items { - new_block_items.push(typecheck_block_item(item, hash_map)?); + new_block_items.push(typecheck_block_item(item, hash_map, function_name)?); } Ok(Block::B(new_block_items)) }@@ -427,12 +737,12 @@ /* HashMap<name, (type, attributes)> */
let mut hash_map: HashMap<String, (Type, IdentifierAttributes)> = HashMap::new(); match ast { - Ast::Program(functions) => { - let mut funcs = Vec::new(); - for func in functions { - funcs.push(typecheck_file_scope_declaration(func, &mut hash_map)?); + Ast::Program(declarations) => { + let mut new_decls = Vec::new(); + for decl in declarations { + new_decls.push(typecheck_file_scope_declaration(decl, &mut hash_map)?); } - Ok((Ast::Program(funcs), hash_map)) + Ok((Ast::Program(new_decls), hash_map)) } } }