类型推导(Type Inference)是现代高级语言中一个越来越常见的特性。其实,这个特性在函数式语言
中早有了广泛应用。而HindleyMilner推导器是所有类型推导器的基础。
在阅读类型推导相关文章的时候,看到了另外一个哥们儿根据“Basic Type Checking”那篇文章实现的一
个简单的HindleyMilner推导器,对它稍作了修改,代码附下。(PS:这个哥们儿在注释的最后一句说了,“Do with it what you will”,我的修改版也保留了。)我的修改主要集中在两个方面:
1. 可读性改进:
a. 在注释中追加语法定义;
b. 更易理解的名字;
c. 使用Identifier代替裸字符串;
d. 使用Constant代替字面量;
2. 在语法级别支持IF表达式:
在原实现中,条件表达式是作为内置函数实现的,而论文中是在语法级别实现的。我扩展了原作者
的语法,支持IF表达式。
- /*
-
*
-
* Andrew Forrest
-
*
-
* Implementation of basic polymorphic type-checking for a simple language.
-
* Based heavily on Nikita Borisov’s Perl implementation at
-
* ~nikitab/courses/cs263/hm.html
-
* which in turn is based on the paper by Luca Cardelli at
-
*
-
*
-
* If you run it with "scala HindleyMilner.scala" it will attempt to report the types
-
* for a few example expressions. (It uses UTF-8 for output, so you may need to set your
-
* terminal accordingly.)
-
*
-
* Changes
-
* June 30, 2011 by Liang Kun(liangkun(AT)baidu.com)
-
* 1. Modify to enhance readability
-
* 2. Extend to Support if expression in syntax
-
*
-
*
-
*
-
* Do with it what you will. :)
-
*/
-
-
/** Syntax definition. This is a simple lambda calculus syntax.
-
* Expression ::= Identifier
-
* | Constant
-
* | "if" Expression "then" Expression "else" Expression
-
* | "lambda(" Identifier ") " Expression
-
* | Expression "(" Expression ")"
-
* | "let" Identifier "=" Expression "in" Expression
-
* | "letrec" Identifier "=" Expression "in" Expression
-
* | "(" Expression ")"
-
* See the examples below in main function.
-
*/
-
// Root of the expression AST. Being sealed, all of its direct subclasses must be declared in this file.
sealed abstract class Expression
-
-
/** A named reference; it prints as the bare name. */
case class Identifier(name: String) extends Expression {
  override def toString: String = name
}
-
-
/** A literal, kept as its source text; it prints as that text. */
case class Constant(value: String) extends Expression {
  override def toString: String = value
}
-
-
/** Conditional expression: `if condition then ... else other`.
 *
 *  The middle field keeps its original name `then`, but is written in
 *  backticks: `then` is deprecated as a plain identifier since Scala 2.10
 *  and is a keyword in Scala 3, so the unescaped form no longer compiles
 *  cleanly. Backticks do not change the field's name.
 */
case class If(condition: Expression, `then`: Expression, other: Expression) extends Expression {
  override def toString: String =
    "(if %s then %s else %s)".format(condition, `then`, other)
}
-
-
/** Single-argument function abstraction; prints as `(lambda argument → body)`. */
case class Lambda(argument: Identifier, body: Expression) extends Expression {
  override def toString: String =
    "(lambda %s → %s)".format(argument, body)
}
-
-
/** Application of a function expression to one argument; prints as `(function argument)`. */
case class Apply(function: Expression, argument: Expression) extends Expression {
  override def toString: String =
    "(%s %s)".format(function, argument)
}
-
-
/** Non-recursive local binding; prints as `(let binding = definition in body)`. */
case class Let(binding: Identifier, definition: Expression, body: Expression) extends Expression {
  override def toString: String =
    "(let %s = %s in %s)".format(binding, definition, body)
}
-
-
/** Recursive local binding; prints as `(letrec binding = definition in body)`. */
case class Letrec(binding: Identifier, definition: Expression, body: Expression) extends Expression {
  override def toString: String =
    "(letrec %s = %s in %s)".format(binding, definition, body)
}
-
-
-
/** Exceptions that may be thrown during type inference */
-
// Thrown by TypeSystem.unify on a type mismatch or a recursive unification.
class TypeError(msg: String) extends Exception(msg)
-
// Thrown by TypeSystem when an identifier is not in the environment or a constant is not an integer literal.
class ParseError(msg: String) extends Exception(msg)
-
-
-
/** Type inference system */
-
/** Type inference system.
 *
 *  Infers the type of an [[Expression]] by walking the tree, creating fresh
 *  type variables where a type is not yet known, and unifying types
 *  destructively (a solved variable records its solution in `instance`).
 *
 *  The object holds mutable counters for variable ids and display names, so
 *  it is not safe to use from several threads at once.
 */
object TypeSystem {

  /** Maps identifiers in scope to their types. */
  type Env = Map[Identifier, Type]
  val EmptyEnv: Map[Identifier, Type] = Map.empty

  // type variable and type operator
  sealed abstract class Type

  /** A type variable. Equality is based on `id` only (case-class equality);
   *  `instance` is filled in by `unify` once the variable has been resolved.
   */
  case class Variable(id: Int) extends Type {
    var instance: Option[Type] = None

    // Lazy so that a display name is consumed only for variables that are actually printed.
    lazy val name: String = nextUniqueName()

    override def toString: String = instance match {
      case Some(t) => t.toString
      case None => name
    }
  }

  /** A type constructor applied to zero or more argument types. */
  case class Operator(name: String, args: Seq[Type]) extends Type {
    // Nullary operators print as their name, binary ones infix, all others prefix.
    override def toString: String = args match {
      case Seq() => name
      case Seq(first, second) => "[" + first + " " + name + " " + second + "]"
      case _ => args.mkString(name + "[", ", ", "]")
    }
  }

  // builtin types, types can be extended by environment
  // (Seq literals rather than Arrays: no implicit Array-to-Seq conversion is needed.)
  def Function(from: Type, to: Type): Operator = Operator("→", Seq(from, to))
  val Integer: Operator = Operator("Integer", Seq.empty[Type])
  val Boolean: Operator = Operator("Boolean", Seq.empty[Type])

  protected var _nextVariableName = 'α'

  /** Hands out the next display name, starting at 'α' and advancing one code point per call. */
  protected def nextUniqueName(): String = {
    val result = _nextVariableName
    _nextVariableName = (_nextVariableName.toInt + 1).toChar
    result.toString
  }

  protected var _nextVariableId = 0

  /** Creates a type variable with an id not used by any earlier call. */
  def newVariable(): Variable = {
    val result = _nextVariableId
    _nextVariableId += 1
    Variable(result)
  }

  /** Main entry point: infers the type of `expr` in `env` with no non-generic variables.
   *
   *  @throws ParseError on an unbound identifier or a constant that is not an integer literal
   *  @throws TypeError  when two types cannot be unified
   */
  def analyze(expr: Expression, env: Env): Type = analyze(expr, env, Set.empty)

  /** Infers the type of `expr`.
   *
   *  @param env        types of the identifiers in scope
   *  @param nongeneric type variables that must be shared, not copied, when an
   *                    identifier's type is instantiated (see `fresh`)
   */
  def analyze(expr: Expression, env: Env, nongeneric: Set[Variable]): Type = expr match {
    case i: Identifier => getIdentifierType(i, env, nongeneric)

    case Constant(value) => getConstantType(value)

    // The pattern variables avoid the name `then`, which is reserved in newer Scala versions.
    case If(cond, thenBranch, elseBranch) =>
      val condType = analyze(cond, env, nongeneric)
      val thenType = analyze(thenBranch, env, nongeneric)
      val elseType = analyze(elseBranch, env, nongeneric)
      unify(condType, Boolean)
      unify(thenType, elseType)
      thenType

    case Apply(func, arg) =>
      val funcType = analyze(func, env, nongeneric)
      val argType = analyze(arg, env, nongeneric)
      val resultType = newVariable()
      unify(Function(argType, resultType), funcType)
      resultType

    case Lambda(arg, body) =>
      // The argument's type variable is non-generic while the body is analyzed.
      val argType = newVariable()
      val resultType = analyze(body, env + (arg -> argType), nongeneric + argType)
      Function(argType, resultType)

    case Let(binding, definition, body) =>
      val definitionType = analyze(definition, env, nongeneric)
      analyze(body, env + (binding -> definitionType), nongeneric)

    case Letrec(binding, definition, body) =>
      // The binding is visible inside its own definition, where its type
      // variable is non-generic; it is analyzed as generic again in the body.
      val newType = newVariable()
      val newEnv = env + (binding -> newType)
      val definitionType = analyze(definition, newEnv, nongeneric + newType)
      unify(newType, definitionType)
      analyze(body, newEnv, nongeneric)
  }

  /** Looks `id` up in `env` and returns a fresh instance of its type. */
  protected def getIdentifierType(id: Identifier, env: Env, nongeneric: Set[Variable]): Type =
    env.get(id) match {
      case Some(declared) => fresh(declared, nongeneric)
      case None => throw new ParseError("Undefined symbol: " + id)
    }

  /** Types a constant: integer literals are `Integer`, anything else is rejected. */
  protected def getConstantType(value: String): Type =
    if (isIntegerLiteral(value)) Integer
    else throw new ParseError("Undefined symbol: " + value)

  /** Copies `t`, replacing each generic variable by a new one (the same new
   *  variable for every occurrence of the same original); non-generic
   *  variables are kept as they are.
   */
  protected def fresh(t: Type, nongeneric: Set[Variable]): Type = {
    import scala.collection.mutable
    val mappings = new mutable.HashMap[Variable, Variable]

    def freshrec(tp: Type): Type = prune(tp) match {
      case v: Variable =>
        if (isgeneric(v, nongeneric)) mappings.getOrElseUpdate(v, newVariable())
        else v
      case Operator(name, args) =>
        Operator(name, args.map(freshrec(_)))
    }

    freshrec(t)
  }

  /** Makes `t1` and `t2` equal by binding unresolved variables.
   *
   *  @throws TypeError on a recursive unification or on operators that differ
   *                    in name or number of arguments
   */
  protected def unify(t1: Type, t2: Type): Unit = {
    (prune(t1), prune(t2)) match {
      case (a: Variable, b) =>
        if (a != b) {
          if (occursintype(a, b))
            throw new TypeError("Recursive unification")
          a.instance = Some(b)
        }

      case (a: Operator, b: Variable) => unify(b, a)

      case (a: Operator, b: Operator) =>
        if (a.name != b.name || a.args.length != b.args.length)
          throw new TypeError("Type mismatch: " + a + " ≠ " + b)
        // Lengths are equal here, so zip pairs up every argument.
        a.args.zip(b.args).foreach { case (left, right) => unify(left, right) }
    }
  }

  // Returns the currently defining instance of t.
  // As a side effect, collapses the list of type instances.
  protected def prune(t: Type): Type = t match {
    case v: Variable =>
      v.instance match {
        case Some(bound) =>
          val inst = prune(bound)
          v.instance = Some(inst)
          inst
        case None => v
      }
    case _ => t
  }

  // Note: must be called with v 'pre-pruned'
  protected def isgeneric(v: Variable, nongeneric: Set[Variable]): Boolean =
    !occursin(v, nongeneric)

  // Note: must be called with v 'pre-pruned'
  protected def occursintype(v: Variable, type2: Type): Boolean =
    prune(type2) match {
      case `v` => true
      case Operator(_, args) => occursin(v, args)
      case _ => false
    }

  /** True when `t` occurs in at least one of the types in `list`. */
  protected def occursin(t: Variable, list: Iterable[Type]): Boolean =
    list exists (t2 => occursintype(t, t2))

  protected val checkDigits = "^(\\d+)$".r

  // Whole-string match: with a find, `$` also matches just before a trailing
  // line terminator, so a value such as "5\n" would be accepted as an integer.
  protected def isIntegerLiteral(name: String): Boolean =
    checkDigits.pattern.matcher(name).matches()
}
-
-
-
/** Demo program */
-
/** Demo program: infers and prints the types of a few example expressions. */
object HindleyMilner {

  /** Builds a small environment of builtins, then reports the type (or the
   *  error message) of each example expression on standard output as UTF-8.
   */
  def main(args: Array[String]): Unit = {
    // extends the system with a new type[pair] and some builtin functions
    val left = TypeSystem.newVariable()
    val right = TypeSystem.newVariable()
    val pairType = TypeSystem.Operator("×", Seq(left, right))

    val myenv: TypeSystem.Env = Map[Identifier, TypeSystem.Type](
      Identifier("pair") -> TypeSystem.Function(left, TypeSystem.Function(right, pairType)),
      Identifier("true") -> TypeSystem.Boolean,
      Identifier("false") -> TypeSystem.Boolean,
      Identifier("zero") -> TypeSystem.Function(TypeSystem.Integer, TypeSystem.Boolean),
      Identifier("pred") -> TypeSystem.Function(TypeSystem.Integer, TypeSystem.Integer),
      Identifier("times") -> TypeSystem.Function(TypeSystem.Integer,
        TypeSystem.Function(TypeSystem.Integer, TypeSystem.Integer))
    )

    // example expressions

    // pair (f 4) (f true) -- shared by two of the examples below
    val pair = Apply(
      Apply(Identifier("pair"), Apply(Identifier("f"), Constant("4"))),
      Apply(Identifier("f"), Identifier("true"))
    )

    val examples = Seq[Expression](
      // factorial
      Letrec(Identifier("factorial"), // letrec factorial =
        Lambda(Identifier("n"), // lambda n =>
          If(
            Apply(Identifier("zero"), Identifier("n")),
            Constant("1"),
            Apply(
              Apply(Identifier("times"), Identifier("n")),
              Apply(
                Identifier("factorial"),
                Apply(Identifier("pred"), Identifier("n"))
              )
            )
          )
        ), // in
        Apply(Identifier("factorial"), Constant("5"))
      ),

      // Should fail:
      // fn x => (pair(x(3) (x(true))))
      Lambda(Identifier("x"),
        Apply(
          Apply(Identifier("pair"), Apply(Identifier("x"), Constant("3"))),
          Apply(Identifier("x"), Identifier("true"))
        )
      ),

      // Should fail (f is not bound in the environment):
      // pair(f(4), f(true))
      Apply(
        Apply(Identifier("pair"), Apply(Identifier("f"), Constant("4"))),
        Apply(Identifier("f"), Identifier("true"))
      ),

      // let f = (fn x => x) in ((pair (f 4)) (f true))
      Let(Identifier("f"), Lambda(Identifier("x"), Identifier("x")), pair),

      // Should fail:
      // fn f => f f
      Lambda(Identifier("f"), Apply(Identifier("f"), Identifier("f"))),

      // let g = fn f => 5 in g g
      Let(
        Identifier("g"),
        Lambda(Identifier("f"), Constant("5")),
        Apply(Identifier("g"), Identifier("g"))
      ),

      // example that demonstrates generic and non-generic variables:
      // fn g => let f = fn x => g in pair (f 3, f true)
      Lambda(Identifier("g"),
        Let(Identifier("f"),
          Lambda(Identifier("x"), Identifier("g")),
          Apply(
            Apply(Identifier("pair"), Apply(Identifier("f"), Constant("3"))),
            Apply(Identifier("f"), Identifier("true"))
          )
        )
      ),

      // Function composition
      // fn f (fn g (fn arg (g (f arg))))
      Lambda(Identifier("f"),
        Lambda(Identifier("g"),
          Lambda(Identifier("arg"),
            Apply(Identifier("g"), Apply(Identifier("f"), Identifier("arg")))
          )
        )
      )
    )

    // Console.setOut is deprecated and no longer exists in current Scala
    // versions; withOut redirects Console output for the duration of the block.
    val utf8Out = new java.io.PrintStream(Console.out, true, "utf-8")
    Console.withOut(utf8Out) {
      for (eg <- examples) {
        tryexp(myenv, eg)
      }
    }
  }

  /** Prints one line for `expr`: its inferred type, or the error message if
   *  inference throws a ParseError or TypeError, then ":", a tab and the
   *  expression itself.
   */
  def tryexp(env: TypeSystem.Env, expr: Expression): Unit = {
    try {
      val t = TypeSystem.analyze(expr, env)
      print(t)
    } catch {
      case e: ParseError => print(e.getMessage)
      case e: TypeError => print(e.getMessage)
    }
    println(":\t" + expr)
  }
}
-
-
// Script entry point: `args` is the documented name for a Scala script's command-line arguments.
HindleyMilner.main(args)
阅读(2918) | 评论(0) | 转发(0) |