类型推导(Type Inference)是现代高级语言中一个越来越常见的特性。其实,这个特性在函数式语言(例如ML系列语言)中早已被广泛使用。
在阅读类型推导相关文章的时候,看到了另外一个哥们儿根据Luca Cardelli的“Basic Polymorphic Typechecking”那篇文章实现的一个简单的Hindley-Milner类型推导器,我对它稍作了修改,代码附下。(PS:这个哥们儿在注释的最后一句说了,“Do with it what you will”,我的修改版也保留了这句话。)我的修改主要集中在两个方面:
1. 可读性改进:
a. 在注释中追加语法定义;
b. 更易理解的名字;
c. 使用Identifier代替裸字符串表示标识符;
d. 使用Constant代替字面量;
2. 在语法级别支持if表达式(if … then … else …)。
/*
 * Andrew Forrest
 *
 * Implementation of basic polymorphic type-checking for a simple language.
 * Based heavily on Nikita Borisov’s Perl implementation at
 * http://www.cs.berkeley.edu/~nikitab/courses/cs263/hm.html
 * which in turn is based on the paper "Basic Polymorphic Typechecking" by Luca Cardelli.
 *
 * If you run it with "scala HindleyMilner.scala" it will attempt to report the types
 * for a few example expressions. (It uses UTF-8 for output, so you may need to set your
 * terminal accordingly.)
 *
 * Changes
 * June 30, 2011 by Liang Kun(liangkun(AT)baidu.com)
 * 1. Modified to enhance readability
 * 2. Extended to support if-expressions in the syntax
 *
 * Do with it what you will. :)
 */
/** Syntax definition. This is a simple lambda calculus syntax.
 *  {{{
 *  Expression ::= Identifier
 *               | Constant
 *               | "if" Expression "then" Expression "else" Expression
 *               | "lambda(" Identifier ") " Expression
 *               | Expression "(" Expression ")"
 *               | "let" Identifier "=" Expression "in" Expression
 *               | "letrec" Identifier "=" Expression "in" Expression
 *               | "(" Expression ")"
 *  }}}
 *  See the examples below in the main function.
 */
sealed abstract class Expression

/** A reference to a name bound by lambda/let/letrec or supplied by the environment. */
case class Identifier(name: String) extends Expression {
  override def toString = name
}

/** A literal, kept as its source text (e.g. "5"). */
case class Constant(value: String) extends Expression {
  override def toString = value
}

/** Conditional expression.
 *  `then` has been a reserved word since Scala 2.10 (and is a keyword in Scala 3),
 *  so the field keeps its name but is written in backticks.
 */
case class If(condition: Expression, `then`: Expression, other: Expression) extends Expression {
  override def toString = "(if " + condition + " then " + `then` + " else " + other + ")"
}

/** Single-argument function abstraction. */
case class Lambda(argument: Identifier, body: Expression) extends Expression {
  override def toString = "(lambda " + argument + " → " + body + ")"
}

/** Function application: `function(argument)`. */
case class Apply(function: Expression, argument: Expression) extends Expression {
  override def toString = "(" + function + " " + argument + ")"
}

/** Non-recursive binding; `binding` is visible in `body` only. */
case class Let(binding: Identifier, definition: Expression, body: Expression) extends Expression {
  override def toString = "(let " + binding + " = " + definition + " in " + body + ")"
}

/** Recursive binding; `binding` is visible in both `definition` and `body`. */
case class Letrec(binding: Identifier, definition: Expression, body: Expression) extends Expression {
  override def toString = "(letrec " + binding + " = " + definition + " in " + body + ")"
}
// Exceptions that may be raised while analyzing an expression.
/** Raised when two types cannot be unified (mismatched type operators or a recursive type). */
class TypeError(msg: String) extends Exception(msg)
/** Raised for an identifier missing from the environment or an unrecognized constant. */
class ParseError(msg: String) extends Exception(msg)
/** Type inference system: Hindley-Milner style inference for [[Expression]]s. */
object TypeSystem {
  /** Typing environment: the (possibly polymorphic) type bound to each identifier. */
  type Env = Map[Identifier, Type]
  val EmptyEnv: Map[Identifier, Type] = Map.empty

  // type variable and type operator
  sealed abstract class Type

  /** A type variable; `instance` is set once unification binds it to another type. */
  case class Variable(id: Int) extends Type {
    var instance: Option[Type] = None
    // Printable name, allocated on first use, so names follow the order types are printed.
    lazy val name = nextUniqueName()
    override def toString = instance match {
      case Some(t) => t.toString
      case None => name
    }
  }

  /** A type constructor applied to argument types, e.g. `Integer` (no args) or `→` (two args). */
  case class Operator(name: String, args: Seq[Type]) extends Type {
    override def toString = {
      if (args.length == 0)
        name
      else if (args.length == 2)
        "[" + args(0) + " " + name + " " + args(1) + "]"
      else
        args.mkString(name + "[", ", ", "]")
    }
  }

  // builtin types, types can be extended by environment.
  // Seq(...) instead of Array(...): avoids the deprecated Array => Seq conversion (2.13).
  def Function(from: Type, to: Type) = Operator("→", Seq(from, to))
  val Integer = Operator("Integer", Seq.empty[Type])
  val Boolean = Operator("Boolean", Seq.empty[Type])

  // Counters for fresh type variables; they are global to TypeSystem and never reset.
  protected var _nextVariableName = 'α'
  protected def nextUniqueName(): String = {
    val result = _nextVariableName
    _nextVariableName = (_nextVariableName.toInt + 1).toChar
    result.toString
  }
  protected var _nextVariableId = 0

  /** Creates a new, unbound type variable with a unique id. */
  def newVariable(): Variable = {
    val result = _nextVariableId
    _nextVariableId += 1
    Variable(result)
  }

  /** Main entry point: infers the type of `expr` under `env`.
   *  @throws ParseError for an undefined identifier or unrecognized constant
   *  @throws TypeError  if the expression is ill-typed
   */
  def analyze(expr: Expression, env: Env): Type = analyze(expr, env, Set.empty)

  /** Infers the type of `expr`.
   *  @param nongeneric type variables introduced by enclosing lambda arguments and by
   *                    letrec bindings under definition; they are shared (not copied)
   *                    when an identifier's type is instantiated.
   */
  def analyze(expr: Expression, env: Env, nongeneric: Set[Variable]): Type = expr match {
    case i: Identifier => getIdentifierType(i, env, nongeneric)
    case Constant(value) => getConstantType(value)
    case If(cond, thenBranch, otherBranch) => {
      val condType = analyze(cond, env, nongeneric)
      val thenType = analyze(thenBranch, env, nongeneric)
      val otherType = analyze(otherBranch, env, nongeneric)
      unify(condType, Boolean)
      // Both branches must agree; their (now unified) type is the type of the if.
      unify(thenType, otherType)
      thenType
    }
    case Apply(func, arg) => {
      val funcType = analyze(func, env, nongeneric)
      val argType = analyze(arg, env, nongeneric)
      val resultType = newVariable()
      unify(Function(argType, resultType), funcType)
      resultType
    }
    case Lambda(arg, body) => {
      val argType = newVariable()
      // The argument is monomorphic inside the body, hence non-generic.
      val resultType = analyze(body, env + (arg -> argType), nongeneric + argType)
      Function(argType, resultType)
    }
    case Let(binding, definition, body) => {
      val definitionType = analyze(definition, env, nongeneric)
      val newEnv = env + (binding -> definitionType)
      analyze(body, newEnv, nongeneric)
    }
    case Letrec(binding, definition, body) => {
      val newType = newVariable()
      val newEnv = env + (binding -> newType)
      // Inside its own definition the binding is non-generic; in the body it is generic.
      val definitionType = analyze(definition, newEnv, nongeneric + newType)
      unify(newType, definitionType)
      analyze(body, newEnv, nongeneric)
    }
  }

  /** Looks up `id` and returns a fresh instance of its type. */
  protected def getIdentifierType(id: Identifier, env: Env, nongeneric: Set[Variable]): Type =
    env.get(id) match {
      case Some(t) => fresh(t, nongeneric)
      case None => throw new ParseError("Undefined symbol: " + id)
    }

  /** Only integer literals are recognized as constants. */
  protected def getConstantType(value: String): Type = {
    if (isIntegerLiteral(value))
      Integer
    else
      throw new ParseError("Undefined symbol: " + value)
  }

  /** Copies `t`, replacing every generic type variable by a fresh one (the same fresh
   *  variable for repeated occurrences) and keeping non-generic variables shared.
   */
  protected def fresh(t: Type, nongeneric: Set[Variable]): Type = {
    import scala.collection.mutable
    val mappings = new mutable.HashMap[Variable, Variable]
    def freshrec(tp: Type): Type = {
      prune(tp) match {
        case v: Variable =>
          if (isgeneric(v, nongeneric))
            mappings.getOrElseUpdate(v, newVariable())
          else
            v
        case Operator(name, args) =>
          Operator(name, args.map(freshrec(_)))
      }
    }
    freshrec(t)
  }

  /** Makes `t1` and `t2` equal by binding type variables (mutates `Variable.instance`).
   *  @throws TypeError on an occurs-check failure or a type-operator mismatch
   */
  protected def unify(t1: Type, t2: Type): Unit = {
    val type1 = prune(t1)
    val type2 = prune(t2)
    (type1, type2) match {
      case (a: Variable, b) =>
        if (a != b) {
          if (occursintype(a, b))
            throw new TypeError("Recursive unification")
          a.instance = Some(b)
        }
      case (a: Operator, b: Variable) => unify(b, a)
      case (a: Operator, b: Operator) =>
        if (a.name != b.name || a.args.length != b.args.length)
          throw new TypeError("Type mismatch: " + a + " ≠ " + b)
        a.args.zip(b.args).foreach { case (x, y) => unify(x, y) }
    }
  }

  // Returns the currently defining instance of t.
  // As a side effect, collapses the list of type instances.
  protected def prune(t: Type): Type = t match {
    case v: Variable =>
      v.instance match {
        case Some(bound) =>
          val inst = prune(bound)
          v.instance = Some(inst)
          inst
        case None => v
      }
    case _ => t
  }

  // Note: must be called with v 'pre-pruned'
  protected def isgeneric(v: Variable, nongeneric: Set[Variable]) = !(occursin(v, nongeneric))

  // Note: must be called with v 'pre-pruned'
  protected def occursintype(v: Variable, type2: Type): Boolean = {
    prune(type2) match {
      case `v` => true
      case Operator(name, args) => occursin(v, args)
      case _ => false
    }
  }

  protected def occursin(t: Variable, list: Iterable[Type]) =
    list exists (t2 => occursintype(t, t2))

  protected val checkDigits = "^(\\d+)$".r
  protected def isIntegerLiteral(name: String) = checkDigits.findFirstIn(name).isDefined
}
/** Demo program: prints the inferred type (or the error) of several example expressions. */
object HindleyMilner {
  def main(args: Array[String]): Unit = {
    // Extend the system with a new type [pair] and some builtin functions.
    val left = TypeSystem.newVariable()
    val right = TypeSystem.newVariable()
    val pairType = TypeSystem.Operator("×", Seq(left, right))
    val myenv: TypeSystem.Env = TypeSystem.EmptyEnv ++ Seq(
      Identifier("pair") -> TypeSystem.Function(left, TypeSystem.Function(right, pairType)),
      Identifier("true") -> TypeSystem.Boolean,
      Identifier("false") -> TypeSystem.Boolean,
      Identifier("zero") -> TypeSystem.Function(TypeSystem.Integer, TypeSystem.Boolean),
      Identifier("pred") -> TypeSystem.Function(TypeSystem.Integer, TypeSystem.Integer),
      Identifier("times") -> TypeSystem.Function(TypeSystem.Integer,
        TypeSystem.Function(TypeSystem.Integer, TypeSystem.Integer)))

    // pair(f(4))(f(true))
    val pair = Apply(
      Apply(Identifier("pair"), Apply(Identifier("f"), Constant("4"))),
      Apply(Identifier("f"), Identifier("true")))

    // example expressions
    val examples = Array[Expression](
      // factorial:
      // letrec factorial = lambda n => if zero(n) then 1 else times(n)(factorial(pred(n)))
      // in factorial(5)
      Letrec(Identifier("factorial"),
        Lambda(Identifier("n"),
          If(Apply(Identifier("zero"), Identifier("n")),
            Constant("1"),
            Apply(
              Apply(Identifier("times"), Identifier("n")),
              Apply(Identifier("factorial"),
                Apply(Identifier("pred"), Identifier("n")))))),
        Apply(Identifier("factorial"), Constant("5"))),

      // Should fail (x is used as both Integer → _ and Boolean → _):
      // lambda x => pair(x(3))(x(true))
      Lambda(Identifier("x"),
        Apply(
          Apply(Identifier("pair"), Apply(Identifier("x"), Constant("3"))),
          Apply(Identifier("x"), Identifier("true")))),

      // Should fail (f is not in the environment): pair(f(4))(f(true))
      Apply(
        Apply(Identifier("pair"), Apply(Identifier("f"), Constant("4"))),
        Apply(Identifier("f"), Identifier("true"))),

      // let f = (lambda x => x) in pair(f(4))(f(true))
      Let(Identifier("f"), Lambda(Identifier("x"), Identifier("x")), pair),

      // Should fail (recursive unification): lambda f => f(f)
      Lambda(Identifier("f"), Apply(Identifier("f"), Identifier("f"))),

      // let g = lambda f => 5 in g(g)
      Let(Identifier("g"),
        Lambda(Identifier("f"), Constant("5")),
        Apply(Identifier("g"), Identifier("g"))),

      // Example that demonstrates generic and non-generic variables:
      // lambda g => let f = lambda x => g in pair(f(3))(f(true))
      Lambda(Identifier("g"),
        Let(Identifier("f"),
          Lambda(Identifier("x"), Identifier("g")),
          Apply(
            Apply(Identifier("pair"), Apply(Identifier("f"), Constant("3"))),
            Apply(Identifier("f"), Identifier("true"))))),

      // Function composition: lambda f => lambda g => lambda arg => g(f(arg))
      Lambda(Identifier("f"),
        Lambda(Identifier("g"),
          Lambda(Identifier("arg"),
            Apply(Identifier("g"), Apply(Identifier("f"), Identifier("arg")))))))

    // Print as UTF-8 so names such as α, → and × display correctly.
    // Console.withOut replaces Console.setOut, which is deprecated and no longer exists in 2.13.
    val utf8Out = new java.io.PrintStream(Console.out, true, "utf-8")
    Console.withOut(utf8Out) {
      for (eg <- examples) {
        tryexp(myenv, eg)
      }
    }
  }

  /** Prints the inferred type of `expr` (or the error message), then the expression itself. */
  def tryexp(env: TypeSystem.Env, expr: Expression): Unit = {
    try {
      val t = TypeSystem.analyze(expr, env)
      print(t)
    } catch {
      case e: ParseError => print(e.getMessage)
      case e: TypeError => print(e.getMessage)
    }
    println(":\t" + expr)
  }
}
阅读(2949) | 评论(0) | 转发(0) |