Untitled

 avatar
unknown
plain_text
5 months ago
6.1 kB
7
Indexable
package wacc

import ast.*
import TypeConverter._
import scala.collection.mutable

/** Immutable mapping from each function's name to its semantic signature, as
  * handed back to callers of the semantic analyser.
  */
type FunctionTable = scala.collection.immutable.Map[String, FuncType]

object semanticAnalyser {

  /** Runs semantic analysis over a whole program.
    *
    * Function signatures are collected first so that functions may call each
    * other regardless of declaration order. Each function is then analysed in
    * its own child scope, followed by the main program body.
    *
    * @param program
    *   The parsed program to analyse.
    * @return
    *   `Left` with every semantic error reported (in the order they were
    *   reported), or `Right` with the global symbol table and an immutable
    *   snapshot of the function table when no errors were reported.
    */
  def analyse(
      program: Program
  ): Either[List[SemanticError], (SymbolTable, FunctionTable)] = {
    given ctx: SemanticContext = SemanticContext(
      SymbolTable(),
      mutable.Map(),
      List.newBuilder
    )

    // Collect all function signatures first to allow for forward declarations.
    // A reused function name is reported; the first definition's signature is
    // the one kept in the table.
    program.functions.foreach { func =>
      if (ctx.functionTable.contains(func.name)) {
        ctx.error(FunctionRedefinitionError(func))
      } else {
        ctx.functionTable += (func.name -> FuncType(
          func.returnType.toSemType,
          func.parameters.map(_._1.toSemType)
        ))
      }
    }

    // Analyse each function's parameters and body.
    program.functions.foreach(analyseFunction)

    // Analyse the main program body.
    analyseBody(program.body)

    val errors = ctx.errorCollector.result()
    if (errors.nonEmpty) Left(errors)
    else Right((ctx.symbolTable, ctx.functionTable.toMap))
  }

  /** Analyses a single function definition.
    *
    * Parameters are declared in a new child scope of the current one rather
    * than in the current scope itself, so they do not leak into the enclosing
    * (global) scope. The body is analysed in that child scope. Finally, the
    * type produced by the function's returning block is checked against the
    * declared return type.
    *
    * @param func
    *   The function to analyse.
    */
  def analyseFunction(
      func: Func
  )(using ctx: SemanticContext): Unit = {
    val funcCtx = childContext(ctx)
    val funcScope = funcCtx.symbolTable

    // Declare parameters in the function's own scope. A repeated parameter
    // name is reported once per extra occurrence (assumes lookupLocal only
    // searches the given scope, not its parents).
    func.parameters.foreach { case (typ, name) =>
      if (funcScope.lookupLocal(name).isDefined) {
        funcCtx.error(VariableRedeclarationError(name))
      } else {
        funcScope.add(name, typ.toSemType)
      }
    }

    analyseBody(func.body)(using funcCtx)

    // Check that the returning block of the function body produces a type
    // matching the declared return type. This runs after the body so that
    // variables declared at the top level of the body are in scope for the
    // final return expression.
    val expectedRetType = func.returnType.toSemType
    val actualRetType = getFunctionReturnType(func, func.body)(using funcCtx)
    if (actualRetType != expectedRetType) {
      funcCtx.error(ReturnTypeMismatch(expectedRetType, actualRetType, func))
    }
  }

  /** Computes the type produced by the returning block at the end of `body`.
    *
    * Pre: `body` is a list of statements within the function `func`, and the
    * function is assumed to be syntactically correct, so its last statement
    * must be a returning block: `return`, `exit`, or an `if` whose branches
    * both end in returning blocks.
    *
    * @param func
    *   The function that the body belongs to. It is used for error reporting,
    *   and its declared return type is used as the type of an `exit`.
    * @param body
    *   A list of statements within the function. This could be the entire
    *   function body, or the body of an `if` branch within the returning block
    *   of the function.
    * @return
    *   - For `return`, the type of the returned expression.
    *   - For `exit`, the declared return type (`exit` never returns normally,
    *     so it satisfies any return type).
    *   - For an `if`, the common type of both branches, or `SemAny` after
    *     reporting `IfReturnTypeMismatchError` when the branches disagree.
    * @throws Exception
    *   If `body` is empty or does not end in a returning block, which the
    *   parser should have rejected.
    */
  def getFunctionReturnType(
      func: Func,
      body: List[Stmt]
  )(using ctx: SemanticContext): SemType = body.lastOption match {
    case Some(Return(expr)) => analyseExpr(expr)
    // The exit code itself is checked to be an int when the statement is
    // analysed by analyseStmt; it does not determine the function's type.
    case Some(Exit(_)) => func.returnType.toSemType
    case Some(If(_, thenBody, elseBody)) =>
      val thenRet = getFunctionReturnType(func, thenBody)
      val elseRet = getFunctionReturnType(func, elseBody)

      if (thenRet == elseRet) {
        thenRet
      } else {
        ctx.error(IfReturnTypeMismatchError(thenRet, elseRet, func))
        SemAny
      }
    case _ =>
      throw new Exception(
        "The last statement of a function must be a returning block. This is a syntax error."
      )
  }

  /** Analyses each statement of `stmts` in order, in the scope held by `ctx`.
    */
  def analyseBody(
      stmts: List[Stmt]
  )(using ctx: SemanticContext): Unit = {
    stmts.foreach(analyseStmt)
  }

  /** Analyses a single statement, reporting any errors to `ctx`.
    *
    * Declarations are added to `ctx.symbolTable`. Nested blocks (`begin`
    * scopes, `if` branches, `while` bodies) are analysed in fresh child scopes,
    * so their declarations are not visible afterwards.
    *
    * @param statement
    *   The statement to analyse.
    */
  def analyseStmt(
      statement: Stmt
  )(using ctx: SemanticContext): Unit = statement match {
    case Skip => // No checking needed

    case Scope(body) =>
      analyseBody(body)(using childContext(ctx))

    case stmt @ Declaration(varType, name, value) =>
      val declaredType = varType.toSemType
      // Analyse the right-hand side before declaring the name, so it is
      // resolved against the scope as it was before this declaration
      // (e.g. `int x = x` must not see the new `x`).
      val valueType = analyseRValue(value)

      if (ctx.symbolTable.lookupLocal(name).isDefined) {
        ctx.error(VariableRedeclarationError(name))
      } else {
        ctx.symbolTable.add(name, declaredType)
      }

      if (valueType != declaredType) {
        ctx.error(DeclarationTypeError(stmt, valueType))
      }

    case Exit(expr) =>
      val typ = analyseExpr(expr)
      if (typ != SemInt)
        ctx.error(ExitTypeError(typ))

    case Free(expr) =>
      analyseExpr(expr)
    // TODO: once freeable types exist:
    // val typ = analyseExpr(expr)
    // if (!typ.isInstanceOf[FreeableType]) {
    //   ctx.error(FreeTypeError(typ))
    // }

    case If(condition, thenBody, elseBody) =>
      val condTyp = analyseExpr(condition)
      if (condTyp != SemBool) {
        ctx.error(ConditionTypeError(statement, condTyp))
      }

      analyseBody(thenBody)(using childContext(ctx))
      analyseBody(elseBody)(using childContext(ctx))

    case Print(expr)   => analyseExpr(expr)
    case PrintLn(expr) => analyseExpr(expr)
    case Read(lvalue) =>
      analyseLValue(lvalue)
    // TODO: once readable types exist:
    // val typ = analyseLValue(lvalue)
    // if (!typ.isInstanceOf[ReadableType])
    //   ctx.error(ReadTypeError(statement, typ))

    case While(condition, body) =>
      val condType = analyseExpr(condition)
      if (condType != SemBool)
        ctx.error(ConditionTypeError(statement, condType))

      analyseBody(body)(using childContext(ctx))

    case Assignment(lvalue, rvalue) =>
      // Analyse both sides instead of crashing the whole analysis.
      // TODO(review): report an error when the two sides' types differ.
      analyseLValue(lvalue)
      analyseRValue(rvalue)

    // TODO: check `return` usage here. Its expression is not analysed here
    // because getFunctionReturnType already analyses the function's final
    // return, and analysing it here too would duplicate errors.
    case Return(_) =>
  }

  /** Placeholder: currently returns `SemInt` for every expression without
    * inspecting it.
    */
  def analyseExpr(expression: Expr)(using ctx: SemanticContext): SemType =
    SemInt

  /** Placeholder: currently returns `SemInt` for every l-value without
    * inspecting it.
    */
  def analyseLValue(lvalue: LValue)(using ctx: SemanticContext): SemType =
    SemInt

  /** Placeholder: currently returns `SemInt` for every r-value without
    * inspecting it.
    */
  def analyseRValue(rvalue: RValue)(using ctx: SemanticContext): SemType =
    SemInt

  /** Returns a copy of `ctx` whose symbol table is a new child scope of
    * `ctx.symbolTable`. The function table and error collector are the same
    * (shared, mutable) instances as in `ctx`.
    */
  private def childContext(ctx: SemanticContext): SemanticContext =
    ctx.copy(symbolTable = ctx.symbolTable.createChild())

  /** State threaded through the analysis.
    *
    * @param symbolTable
    *   The current (innermost) variable scope.
    * @param functionTable
    *   Signatures of all functions. It is shared by every copy of the context.
    * @param errorCollector
    *   Accumulates every error reported via `error`. It is shared by every
    *   copy of the context.
    */
  case class SemanticContext(
      symbolTable: SymbolTable,
      functionTable: mutable.Map[String, FuncType],
      errorCollector: mutable.Builder[SemanticError, List[SemanticError]]
  ) {

    /** Records `error` in the shared error collector. */
    def error(error: SemanticError): Unit = errorCollector += error
  }
}
Editor is loading...
Leave a Comment