
Text Classification

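This tutorial shows how to use the OpenAI API together with Instructor to perform text classification, specifically single-label and multi-label classification.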

Motivation

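Text classification is a common problem in many NLP applications, such as spam detection or support ticket categorization. Our goal is to provide a systematic approach to handling these cases with OpenAI's GPT models.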

Single-Label Classification

Defining the Structures

For single-label classification, we first define an enum of the possible labels and a Zod schema for the output.

import Instructor from "@instructor-ai/instructor"
import OpenAI from "openai"
import { z } from "zod"

enum CLASSIFICATION_LABELS {
  "SPAM" = "SPAM",
  "NOT_SPAM" = "NOT_SPAM"
}

const SimpleClassificationSchema = z.object({
  class_label: z.nativeEnum(CLASSIFICATION_LABELS)
})

type SimpleClassification = z.infer<typeof SimpleClassificationSchema>
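The schema above relies on z.nativeEnum to reject any label outside CLASSIFICATION_LABELS. For illustration only, the following dependency-free sketch performs roughly the same check by hand. parseSimpleClassification and ALLOWED_LABELS are hypothetical names, the block repeats the enum so that it is self-contained, and this is not how Zod is implemented internally.

```typescript
// Hypothetical, dependency-free sketch of roughly the check that
// z.nativeEnum(CLASSIFICATION_LABELS) performs: accept only enum values.
enum CLASSIFICATION_LABELS {
  SPAM = "SPAM",
  NOT_SPAM = "NOT_SPAM"
}

interface SimpleClassification {
  class_label: CLASSIFICATION_LABELS
}

// String enums have no reverse mapping, so Object.values yields only the labels.
const ALLOWED_LABELS = new Set<string>(Object.values(CLASSIFICATION_LABELS))

// Returns the parsed object, or null when the shape or the label is invalid.
function parseSimpleClassification(raw: unknown): SimpleClassification | null {
  if (typeof raw !== "object" || raw === null) return null
  const label = (raw as { class_label?: unknown }).class_label
  if (typeof label !== "string" || !ALLOWED_LABELS.has(label)) return null
  return { class_label: label as CLASSIFICATION_LABELS }
}

console.log(parseSimpleClassification({ class_label: "SPAM" })) // { class_label: 'SPAM' }
console.log(parseSimpleClassification({ class_label: "spam" })) // null
```

Matching is case-sensitive, so a lowercase "spam" is rejected. This is the kind of output that makes validation, and therefore a retry, fail.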

Classifying Text

The function classify performs single-label classification.

const oai = new OpenAI({
  apiKey: process.env.OPENAI_API_KEY ?? undefined,
  organization: process.env.OPENAI_ORG_ID ?? undefined
})

const client = Instructor({
  client: oai,
  mode: "FUNCTIONS"
})

async function classify(data: string): Promise<SimpleClassification> {
  const classification = await client.chat.completions.create({
    messages: [{ role: "user", content: `Classify the following text: ${data}` }],
    model: "gpt-3.5-turbo",
    response_model: { schema: SimpleClassificationSchema, name: "SimpleClassification" },
    max_retries: 3
  })

  return classification
}

const classification = await classify(
  "Hello there I'm a nigerian prince and I want to give you money"
)

console.log(classification)
// { class_label: 'SPAM' }
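The max_retries option asks Instructor to try again when the model's output does not validate against the schema. The sketch below shows the general idea of such a validate-and-retry loop. It uses a stubbed model so that it runs offline. classifyWithRetries, ModelFn and stubModel are hypothetical names. The way the validation failure is fed back to the model is an assumption, not Instructor's actual implementation.

```typescript
// Hypothetical sketch of a validate-and-retry loop, roughly the idea behind
// `max_retries` (not Instructor's actual implementation).
enum CLASSIFICATION_LABELS {
  SPAM = "SPAM",
  NOT_SPAM = "NOT_SPAM"
}

type Message = { role: "user" | "assistant"; content: string }
// Stand-in for a model call: takes the conversation, returns raw JSON text.
type ModelFn = (messages: Message[]) => Promise<string>

function isLabel(value: unknown): value is CLASSIFICATION_LABELS {
  return typeof value === "string" &&
    (Object.values(CLASSIFICATION_LABELS) as string[]).includes(value)
}

async function classifyWithRetries(
  model: ModelFn,
  data: string,
  maxRetries = 3
): Promise<{ class_label: CLASSIFICATION_LABELS }> {
  const messages: Message[] = [
    { role: "user", content: `Classify the following text: ${data}` }
  ]
  // One initial attempt plus up to `maxRetries` retries.
  for (let attempt = 0; attempt <= maxRetries; attempt++) {
    const raw = await model(messages)
    let label: unknown = undefined
    try {
      label = (JSON.parse(raw) as { class_label?: unknown }).class_label
    } catch {
      // Malformed JSON (or a JSON null): treat it as an invalid answer.
    }
    if (isLabel(label)) return { class_label: label }
    // Feed the failure back so the next attempt can correct itself.
    messages.push(
      { role: "assistant", content: raw },
      {
        role: "user",
        content: `Invalid output. class_label must be one of: ${Object.values(CLASSIFICATION_LABELS).join(", ")}`
      }
    )
  }
  throw new Error(`No valid classification after ${maxRetries + 1} attempts`)
}

// Stub model: answers with an invalid label first, then a valid one.
let calls = 0
const stubModel: ModelFn = async () =>
  ++calls === 1 ? '{"class_label":"junk"}' : '{"class_label":"SPAM"}'

classifyWithRetries(stubModel, "Hello there I'm a nigerian prince").then(result =>
  console.log(result, "after", calls, "calls") // { class_label: 'SPAM' } after 2 calls
)
```

Because the model call sits behind the ModelFn type, the loop can be exercised without network access. The same approach can make code like classify easier to test.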

Multi-Label Classification

Defining the Structures

For multi-label classification, we introduce a new enum and a different Zod schema that can hold multiple labels.

enum MULTI_CLASSIFICATION_LABELS {
  "BILLING" = "billing",
  "GENERAL_QUERY" = "general_query",
  "HARDWARE" = "hardware"
}

const MultiClassificationSchema = z.object({
  predicted_labels: z.array(z.nativeEnum(MULTI_CLASSIFICATION_LABELS))
})

type MultiClassification = z.infer<typeof MultiClassificationSchema>
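Note that z.array(z.nativeEnum(...)) accepts an empty array and does not remove duplicate labels. The dependency-free sketch below checks every label against MULTI_CLASSIFICATION_LABELS. As an extra post-processing step that the Zod schema does not perform, it also drops duplicates. parseMultiClassification is a hypothetical name, and the enum is repeated so that the block is self-contained.

```typescript
// Hypothetical, dependency-free sketch of checking a multi-label result.
enum MULTI_CLASSIFICATION_LABELS {
  BILLING = "billing",
  GENERAL_QUERY = "general_query",
  HARDWARE = "hardware"
}

interface MultiClassification {
  predicted_labels: MULTI_CLASSIFICATION_LABELS[]
}

const ALLOWED = new Set<string>(Object.values(MULTI_CLASSIFICATION_LABELS))

// Rejects the whole result if any label is unknown (as the Zod schema would),
// then drops duplicates while keeping first-seen order (the schema does not).
function parseMultiClassification(raw: unknown): MultiClassification | null {
  if (typeof raw !== "object" || raw === null) return null
  const labels = (raw as { predicted_labels?: unknown }).predicted_labels
  if (!Array.isArray(labels)) return null
  if (!labels.every(l => typeof l === "string" && ALLOWED.has(l))) return null
  const unique = Array.from(new Set(labels as MULTI_CLASSIFICATION_LABELS[]))
  return { predicted_labels: unique }
}

console.log(parseMultiClassification({ predicted_labels: ["billing", "hardware", "billing"] }))
// { predicted_labels: [ 'billing', 'hardware' ] }
console.log(parseMultiClassification({ predicted_labels: ["billing", "refund"] }))
// null
```

If an empty prediction should count as invalid, Zod's .nonempty() on the array is one way to express that in the schema itself.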

Classifying Text

The function multi_classify handles multi-label classification.

async function multi_classify(data: string): Promise<MultiClassification> {
  const classification = await client.chat.completions.create({
    messages: [{ role: "user", content: `Classify the following support ticket: ${data}` }],
    model: "gpt-3.5-turbo",
    response_model: { schema: MultiClassificationSchema, name: "MultiClassification" },
    max_retries: 3
  })
  return classification
}

const multiClassification = await multi_classify(
  "My account is locked and I can't access my billing info. Phone is also broken"
)

console.log(multiClassification)
// { predicted_labels: [ 'billing', 'hardware' ] }
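A multi-label result is a list, so the model may return the same labels in a different order (for example [ 'hardware', 'billing' ]). When you check outputs like the one above in a test, comparing the labels as sets avoids spurious failures. sameLabelSet is a hypothetical helper and is not part of Instructor.

```typescript
// Hypothetical helper for checking multi-label output in a test: the model may
// return labels in any order, so compare them as sets rather than as arrays.
function sameLabelSet(actual: readonly string[], expected: readonly string[]): boolean {
  const a = new Set(actual)
  const e = new Set(expected)
  // Same number of distinct labels, and every actual label is expected.
  return a.size === e.size && Array.from(a).every(label => e.has(label))
}

console.log(sameLabelSet(["hardware", "billing"], ["billing", "hardware"])) // true
console.log(sameLabelSet(["billing"], ["billing", "hardware"])) // false
```

Set comparison also ignores repeated labels. If duplicates matter in your application, compare sorted arrays instead.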