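Text Classification
This tutorial shows how to implement text classification tasks, specifically single-label and multi-label classification, using the OpenAI API.
Motivation
Text classification is a common problem in many NLP applications, such as spam detection or support ticket categorization. The goal is to provide a systematic way to handle these cases using OpenAI's GPT models.
Single-Label Classification
Defining the Structures
For single-label classification, we first define an enum for the possible labels and a Zod schema for the output.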
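// "@/instructor" is a path alias inside the instructor-js repository;
// outside it, import from the published npm package instead.
import Instructor from "@instructor-ai/instructor"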
import OpenAI from "openai"
import { z } from "zod"
enum CLASSIFICATION_LABELS {
"SPAM" = "SPAM",
"NOT_SPAM" = "NOT_SPAM"
}
const SimpleClassificationSchema = z.object({
class_label: z.nativeEnum(CLASSIFICATION_LABELS)
})
type SimpleClassification = z.infer<typeof SimpleClassificationSchema>
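To make concrete what this schema enforces, the following is a minimal, dependency-free sketch of the same check: the output must be an object whose class_label is one of the enum's values. It is not how Zod is implemented; the names Label, isLabel, and parseSimpleClassification are hypothetical and exist only for this illustration.

```typescript
// Illustration only: hypothetical names, not part of Zod or Instructor.
enum Label {
  SPAM = "SPAM",
  NOT_SPAM = "NOT_SPAM"
}

const LABEL_VALUES: readonly string[] = [Label.SPAM, Label.NOT_SPAM]

// Type guard: true only when `value` is one of the enum's string values.
function isLabel(value: unknown): value is Label {
  return typeof value === "string" && LABEL_VALUES.indexOf(value) !== -1
}

type ParseResult =
  | { success: true; data: { class_label: Label } }
  | { success: false; error: string }

// Accepts an unknown value and reports whether it has the expected shape.
function parseSimpleClassification(raw: unknown): ParseResult {
  if (typeof raw !== "object" || raw === null) {
    return { success: false, error: "expected an object" }
  }
  const label = (raw as Record<string, unknown>).class_label
  if (!isLabel(label)) {
    return { success: false, error: `invalid class_label: ${String(label)}` }
  }
  return { success: true, data: { class_label: label } }
}
```

In the tutorial itself, SimpleClassificationSchema plays this role, so a label outside the enum is rejected rather than passed through.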
Classifying Text
The function classify performs single-label classification.
const oai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY ?? undefined,
organization: process.env.OPENAI_ORG_ID ?? undefined
})
const client = Instructor({
client: oai,
mode: "FUNCTIONS"
})
async function classify(data: string): Promise<SimpleClassification> {
const classification = await client.chat.completions.create({
messages: [{ role: "user", content: `Classify the following text: ${data}` }],
model: "gpt-3.5-turbo",
response_model: { schema: SimpleClassificationSchema },
max_retries: 3
})
return classification
}
const classification = await classify(
"Hello there I'm a nigerian prince and I want to give you money"
)
console.log({ classification })
// e.g. { classification: { class_label: 'SPAM' } }
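The max_retries: 3 option above is easier to reason about with a small model of retry-on-validation-failure. The sketch below is not Instructor's implementation. It assumes that a retry happens when the response fails validation and that max_retries counts additional attempts after the first one; both are assumptions. It is kept synchronous for brevity, whereas a real model call would be awaited. All names (withRetries, validateSpamLabel, fakeResponses) are hypothetical.

```typescript
// Sketch only: NOT Instructor's implementation. All names are hypothetical.
type Validator<T> = (raw: unknown) => T | null

// Calls `attempt` once, then up to `maxRetries` more times,
// returning the first result that `validate` accepts.
function withRetries<T>(
  attempt: (tryIndex: number) => unknown,
  validate: Validator<T>,
  maxRetries: number
): T {
  for (let tryIndex = 0; tryIndex <= maxRetries; tryIndex++) {
    const parsed = validate(attempt(tryIndex))
    if (parsed !== null) {
      return parsed
    }
  }
  throw new Error(`no valid response after ${maxRetries + 1} attempts`)
}

// Accepts only the two labels used in this tutorial.
function validateSpamLabel(raw: unknown): "SPAM" | "NOT_SPAM" | null {
  if (raw === "SPAM") return "SPAM"
  if (raw === "NOT_SPAM") return "NOT_SPAM"
  return null
}

// Stand-in for a model that answers badly twice before returning a valid label.
const fakeResponses: unknown[] = ["maybe", 42, "SPAM"]
const retriedLabel = withRetries(i => fakeResponses[i], validateSpamLabel, 3)
console.log(retriedLabel)
```

With a budget of 3 retries the third response is accepted; with a budget of 1 the same sequence would exhaust its attempts and throw.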
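Multi-Label Classification
Defining the Structures
For multi-label classification, we introduce a new enum and a different Zod schema that accepts multiple labels.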
enum MULTI_CLASSIFICATION_LABELS {
"BILLING" = "billing",
"GENERAL_QUERY" = "general_query",
"HARDWARE" = "hardware"
}
const MultiClassificationSchema = z.object({
predicted_labels: z.array(z.nativeEnum(MULTI_CLASSIFICATION_LABELS))
})
type MultiClassification = z.infer<typeof MultiClassificationSchema>
Classifying Text
The function multi_classify performs multi-label classification.
async function multi_classify(data: string): Promise<MultiClassification> {
const classification = await client.chat.completions.create({
messages: [{ role: "user", content: `Classify the following support ticket: ${data}` }],
model: "gpt-3.5-turbo",
response_model: { schema: MultiClassificationSchema },
max_retries: 3
})
return classification
}
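// Named differently from the single-label result so both examples can share one module.
const multiClassification = await multi_classify(
"My account is locked and I can't access my billing info. Phone is also broken"
)
console.log({ multiClassification })
// e.g. { multiClassification: { predicted_labels: [ 'billing', 'hardware' ] } }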
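Since predicted_labels is an array, it can help to see what "an array of enum values" means as a check. The following dependency-free sketch mirrors that rule and adds an optional de-duplication step. The de-duplication is an extra shown for illustration, since the schema above does not remove repeated labels. TicketLabel, parsePredictedLabels, and uniqueLabels are hypothetical names.

```typescript
// Illustration only: hypothetical names mirroring MULTI_CLASSIFICATION_LABELS.
enum TicketLabel {
  BILLING = "billing",
  GENERAL_QUERY = "general_query",
  HARDWARE = "hardware"
}

const TICKET_LABEL_VALUES: readonly string[] = [
  TicketLabel.BILLING,
  TicketLabel.GENERAL_QUERY,
  TicketLabel.HARDWARE
]

// Returns the labels if every element is a known label, otherwise null.
function parsePredictedLabels(raw: unknown): TicketLabel[] | null {
  if (!Array.isArray(raw)) return null
  const labels: TicketLabel[] = []
  for (const item of raw) {
    if (typeof item !== "string" || TICKET_LABEL_VALUES.indexOf(item) === -1) {
      return null
    }
    labels.push(item as TicketLabel)
  }
  return labels
}

// Optional post-processing: drop repeated labels, keeping first-seen order.
function uniqueLabels(labels: TicketLabel[]): TicketLabel[] {
  return labels.filter((label, index) => labels.indexOf(label) === index)
}
```

A single unknown label invalidates the whole array here, which matches the all-or-nothing behavior one would expect from validating an array against an enum.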