最近在做dialog policy相关的研究,实现就用了rasa的轮子,看源码顺便写篇文章。水平有限,还请指正。
先吹一波
rasa core的代码质量非常高非常高非常高!我知道有许多中国工程师参与了开发,牛逼!
整体思路
我们先从执行的角度来分析tracker的源码
- 是什么?如何初始化
- 输入是什么
- 跟踪了什么内容
- 如何更新状态
- 状态如何表达
如何初始化 —— init
下面是init的全部代码,非常简单,我做了一些注释方便理解,其实源码里面的注释很多,原来的注释大家就直接看源码吧。多说一句,rasa的源码写得太漂亮了,注释详细,格式规范,读起来就是享受。
def __init__(self, sender_id, slots, max_event_history=None): """Initialize the tracker. A set of events can be stored externally, and we will run through all of them to get the current state. The tracker will represent all the information we captured while processing messages of the dialogue.""" # 可以跟踪的最长历史,tracker记录状态是以event为单位的 self._max_event_history = max_event_history # 历史事件列表 self.events = self._create_events([]) # 这个id和rasa的chenel特性有关系 self.sender_id = sender_id # slot列表 self.slots = {slot.name: copy.deepcopy(slot) for slot in slots} ### # current state of the tracker - MUST be re-creatable by processing # all the events. This only defines the attributes, values are set in # `reset()` ### # 暂停标志 self._paused = None # 一些action记录 self.followup_action = ACTION_LISTEN_NAME self.latest_action_name = None self.latest_message = None # bot的上一个返回内容 self.latest_bot_utterance = None self._reset()复制代码
从init函数中我们可以知道些什么呢?
- tracker是记录一个用户对话状态的对象
- tracker基于Event对象跟踪对话状态
Event
既然tracker是基于Event的,我们就来看看Event是啥
简单来说,Event就是对bot一切行为的抽象,每一个具体的事件类都继承自Event基类
class Event(object): """Events describe everything that occurs in a conversation and tell the :class:`DialogueStateTracker` how to update its state.""" type_name = "event" def __init__(self, timestamp=None): self.timestamp = timestamp if timestamp else time.time()复制代码
这种设计很优秀,使得tracker可以跟踪系统预定义以外的事件,只要你自己实现一个Event的子类就行。说起来这是应该是面向对象的基本设计思维,但是真正编码的时候很难考虑周全。
rasa-core内部实现了以下Event
名字一看就知道大概什么意思了
下面我们看一下Event核心的方法apply_to()
class UserUttered(Event): def apply_to(self, tracker): # type: (DialogueStateTracker) -> None tracker.latest_message = self tracker.clear_followup_action()复制代码
看一个就行,这是在干嘛呢?就是给tracker改属性,把一些和自己有关的内容更新了。
为什么要有这个方法呢?因为每个Event需要修改的属性不一样,把这部分逻辑放到子类自己实现,调用逻辑在tracker实现,最大化复用代码。这同样应该属于基础思维,那么自己做到了么(逃
状态更新 —— update
def update(self, event): # type: (Event) -> None """Modify the state of the tracker according to an ``Event``. """ if not isinstance(event, Event): # pragma: no cover raise ValueError("event to log must be an instance " "of a subclass of Event.") self.events.append(event) event.apply_to(self)复制代码
就是这么简单
输出
上面说的内容就是tracker的核心部分了,抽象非常优美。题外话,推荐大家读一读Flask的源码,我读了一部分,说赏心悦目不为过,那种架构设计的严谨优雅看着是真tm舒服。
tracker记录了整个交流的过程,提供了生成Story的接口和生成Dialog的接口
def export_stories(self): # type: () -> Text """Dump the tracker as a story in the Rasa Core story format. Returns the dumped tracker as a string.""" from dqn_policy.training.structures import Story story = Story.from_events(self.applied_events()) return story.as_story_string(flat=True) def as_dialogue(self): # type: () -> Dialogue """Return a ``Dialogue`` object containing all of the turns. This can be serialised and later used to recover the state of this tracker exactly.""" return Dialogue(self.sender_id, list(self.events))复制代码
其他接口
tracker还实现了很多接口,涉及到了rasa的各个部分,就不一一细说了。里面很多是用来featurize的辅助接口,我也还没把这部分研究透,后面会再写一篇聊featurize,这是rasa core的核心组件
总结一哈
tracker是rasa core中承上启下的一环,它记录来自前端输入的数据,又为模型训练的featurize提供基础。从tracker出发基本能摸清楚整个rasa core的框架结构。rasa core抽象做得非常好,代码质量贼高,必须吹一波。这部分源码相对比较简单,注释非常详细,读起来很舒服,推荐大家都读一读。